from typing import Optional, Tuple
import copy
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
from fsrl.utils import DummyLogger, WandbLogger
from torch.distributions.beta import Beta
from torch.nn import functional as F  # noqa
from tqdm.auto import trange  # noqa
from torch.distributions.normal import Normal

from osrl.common.net import DiagGaussianActor, TransformerBlock, mlp


class CDT(nn.Module):
    """
    Constrained Decision Transformer (CDT)

    The model embeds, for every timestep, an optional return-to-go token, an
    optional cost-to-go token, an optional prompt token, a state token and an
    action token (in that order), runs them through a stack of transformer
    blocks and predicts the action from the state token and the next cost /
    next state from the action token.

    Args:
        state_dim (int): dimension of the state space.
        action_dim (int): dimension of the action space.
        max_action (float): Maximum action value.
        seq_len (int): The length of the sequence to process.
        episode_len (int): The length of the episode.
        embedding_dim (int): The dimension of the embeddings.
        num_layers (int): The number of transformer layers to use.
        num_heads (int): The number of heads to use in the multi-head attention.
        attention_dropout (float): The dropout probability for attention layers.
        residual_dropout (float): The dropout probability for residual layers.
        embedding_dropout (float): The dropout probability for embedding layers.
        time_emb (bool): Whether to include time embeddings.
        use_rew (bool): Whether to include return embeddings.
        use_cost (bool): Whether to include cost embeddings.
        cost_transform (bool): Whether to transform the cost values
            (``50 - cost``) before embedding them.
        add_cost_feat (bool): Whether to add cost features.
        mul_cost_feat (bool): Whether to multiply cost features.
        cat_cost_feat (bool): Whether to concatenate cost features.
            Only has an effect together with ``use_cost``.
        action_head_layers (int): The number of layers in the action head.
        cost_prefix (bool): Whether to include a cost prefix.
        stochastic (bool): Whether to use stochastic actions.
        init_temperature (float): The initial temperature value for stochastic actions.
        target_entropy (float): The target entropy value for stochastic actions.
        use_prompt (bool): Whether to feed prompt tokens to the transformer.
            Forced to ``False`` when ``prompt_concat`` is set.
        prompt_prefix (bool): With ``use_prompt``, prepend a single prompt
            token to the sequence instead of one prompt token per timestep.
        prompt_concat (bool): Stored flag; when set, prompt tokens are
            disabled in this module.
        prompt_dim (int): Feature dimension of the prompt.
    """

    def __init__(
        self,
        state_dim: int,
        action_dim: int,
        max_action: float,
        seq_len: int = 10,
        episode_len: int = 1000,
        embedding_dim: int = 128,
        num_layers: int = 4,
        num_heads: int = 8,
        attention_dropout: float = 0.0,
        residual_dropout: float = 0.0,
        embedding_dropout: float = 0.0,
        time_emb: bool = True,
        use_rew: bool = False,
        use_cost: bool = False,
        cost_transform: bool = False,
        add_cost_feat: bool = False,
        mul_cost_feat: bool = False,
        cat_cost_feat: bool = False,
        action_head_layers: int = 1,
        cost_prefix: bool = False,
        stochastic: bool = False,
        init_temperature: float = 0.1,
        target_entropy: Optional[float] = None,
        use_prompt: bool = False,
        prompt_prefix: bool = False,
        prompt_concat: bool = False,
        prompt_dim: int = 16
    ):
        super().__init__()
        self.seq_len = seq_len
        self.embedding_dim = embedding_dim
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.episode_len = episode_len
        self.max_action = max_action
        if cost_transform:
            self.cost_transform = lambda x: 50 - x
        else:
            self.cost_transform = None
        self.add_cost_feat = add_cost_feat
        self.mul_cost_feat = mul_cost_feat
        self.cat_cost_feat = cat_cost_feat
        self.stochastic = stochastic

        # NOTE: the order in which sub-modules are created is kept stable on
        # purpose; it determines state-dict key order and RNG consumption.
        self.emb_drop = nn.Dropout(embedding_dropout)
        self.emb_norm = nn.LayerNorm(embedding_dim)

        self.out_norm = nn.LayerNorm(embedding_dim)
        # additional seq_len embeddings for padding timesteps
        self.time_emb = time_emb
        if self.time_emb:
            self.timestep_emb = nn.Embedding(episode_len + seq_len, embedding_dim)

        self.state_emb = nn.Linear(state_dim, embedding_dim)
        self.action_emb = nn.Linear(action_dim, embedding_dim)

        # number of tokens per timestep; starts with (state, action)
        self.seq_repeat = 2
        self.use_rew = use_rew
        self.use_cost = use_cost
        self.use_prompt = use_prompt
        if self.use_cost:
            self.cost_emb = nn.Linear(1, embedding_dim)
            self.seq_repeat += 1
        if self.use_rew:
            self.return_emb = nn.Linear(1, embedding_dim)
            self.seq_repeat += 1

        dt_seq_len = self.seq_repeat * seq_len

        self.prompt_prefix = prompt_prefix
        self.prompt_concat = prompt_concat
        if self.prompt_concat:
            # prompt tokens are not used in this mode
            self.use_prompt = False
        self.prompt_dim = prompt_dim
        if self.use_prompt:
            self.prompt_emb = nn.Linear(prompt_dim, embedding_dim)
            if not self.prompt_prefix:
                # one extra prompt token per timestep
                self.seq_repeat += 1
                dt_seq_len += seq_len
            else:
                # a single prompt token in front of the whole sequence
                dt_seq_len += 1

        self.cost_prefix = cost_prefix
        if self.cost_prefix:
            self.prefix_emb = nn.Linear(1, embedding_dim)
            dt_seq_len += 1

        self.blocks = nn.ModuleList([
            TransformerBlock(
                seq_len=dt_seq_len,
                embedding_dim=embedding_dim,
                num_heads=num_heads,
                attention_dropout=attention_dropout,
                residual_dropout=residual_dropout,
            ) for _ in range(num_layers)
        ])

        # The cost embedding is only concatenated in forward() when use_cost is
        # also set, so the head must only be widened in that case; otherwise
        # the head input size would not match the state feature.
        widen_head = self.cat_cost_feat and self.use_cost
        action_emb_dim = 2 * embedding_dim if widen_head else embedding_dim

        if self.stochastic:
            if action_head_layers >= 2:
                self.action_head = nn.Sequential(
                    nn.Linear(action_emb_dim, action_emb_dim), nn.GELU(),
                    DiagGaussianActor(action_emb_dim, action_dim))
            else:
                self.action_head = DiagGaussianActor(action_emb_dim, action_dim)
        else:
            self.action_head = mlp([action_emb_dim] * action_head_layers + [action_dim],
                                   activation=nn.GELU,
                                   output_activation=nn.Identity)
        self.state_pred_head = nn.Linear(embedding_dim, state_dim)
        # a classification problem
        self.cost_pred_head = nn.Linear(embedding_dim, 2)

        if self.stochastic:
            # Plain tensor (not an nn.Parameter): it is therefore not part of
            # self.parameters() and is optimised separately by the trainer.
            self.log_temperature = torch.tensor(np.log(init_temperature))
            self.log_temperature.requires_grad = True
            self.target_entropy = target_entropy

        self.apply(self._init_weights)

    def temperature(self):
        """Return ``exp(log_temperature)`` for stochastic models, else ``None``."""
        if self.stochastic:
            return self.log_temperature.exp()
        else:
            return None

    @staticmethod
    def _init_weights(module: nn.Module):
        """Initialise Linear/Embedding weights with N(0, 0.02) and reset LayerNorm."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            torch.nn.init.zeros_(module.bias)
            torch.nn.init.ones_(module.weight)

    def _action_mu_std(self, feature: torch.Tensor):
        """Query the stochastic action head with ``return_mu_std=True``.

        ``nn.Sequential.forward`` does not accept keyword arguments, so for a
        multi-layer head the trunk is applied first and the keyword is passed
        to the final actor module only.
        """
        head = self.action_head
        if isinstance(head, nn.Sequential):
            feature = head[:-1](feature)
            head = head[-1]
        return head(feature, return_mu_std=True)

    def forward(
            self,
            states: torch.Tensor,  # [batch_size, seq_len, state_dim]
            actions: torch.Tensor,  # [batch_size, seq_len, action_dim]
            returns_to_go: torch.Tensor,  # [batch_size, seq_len]
            costs_to_go: torch.Tensor,  # [batch_size, seq_len]
            time_steps: torch.Tensor,  # [batch_size, seq_len]
            prompt: Optional[torch.Tensor] = None,  # [batch_size, seq_len, prompt_dim] or [batch_size, 1, prompt_dim]
            padding_mask: Optional[torch.Tensor] = None,  # [batch_size, seq_len]
            episode_cost: Optional[torch.Tensor] = None,  # [batch_size, ]
            return_mu_std: bool = False
    ) -> tuple:
        """Run the transformer over one batch of trajectories.

        Args:
            states: state sequence.
            actions: action sequence.
            returns_to_go: return-to-go sequence (embedded only if ``use_rew``).
            costs_to_go: cost-to-go sequence (embedded only if ``use_cost``).
            time_steps: integer timestep indices for the time embedding.
            prompt: prompt features; required when ``use_prompt`` is active.
            padding_mask: per-timestep mask forwarded to the transformer
                blocks after being repeated for every token of a timestep.
            episode_cost: per-episode cost; required when ``cost_prefix`` is set.
            return_mu_std: for stochastic models, return ``mu`` and
                ``log_std`` of the action head instead of its default output.

        Returns:
            ``(action_preds, cost_preds, state_preds)`` or, when the model is
            stochastic and ``return_mu_std`` is set,
            ``(mu, log_std, cost_preds, state_preds)``. ``cost_preds`` are
            log-probabilities over two classes.

        Raises:
            ValueError: if ``prompt`` or ``episode_cost`` is required by the
                model configuration but not provided.
        """
        if self.use_prompt and prompt is None:
            raise ValueError("`prompt` must be provided when use_prompt is enabled")
        if self.cost_prefix and episode_cost is None:
            raise ValueError("`episode_cost` must be provided when cost_prefix is enabled")

        batch_size, seq_len = states.shape[0], states.shape[1]
        # [batch_size, seq_len, emb_dim]
        if self.time_emb:
            timestep_emb = self.timestep_emb(time_steps)
        else:
            timestep_emb = 0.0
        state_emb = self.state_emb(states) + timestep_emb
        act_emb = self.action_emb(actions) + timestep_emb

        seq_list = [state_emb, act_emb]

        if self.cost_transform is not None:
            costs_to_go = self.cost_transform(costs_to_go.detach())

        # Tokens are inserted at the front, so the final per-timestep order is
        # (return, cost, prompt, state, action) with the unused ones left out.
        if self.use_prompt and not self.prompt_prefix:
            prompt_emb = self.prompt_emb(prompt) + timestep_emb
            seq_list.insert(0, prompt_emb)

        if self.use_cost:
            costs_emb = self.cost_emb(costs_to_go.unsqueeze(-1)) + timestep_emb
            seq_list.insert(0, costs_emb)
        if self.use_rew:
            returns_emb = self.return_emb(returns_to_go.unsqueeze(-1)) + timestep_emb
            seq_list.insert(0, returns_emb)

        # [batch_size, seq_len, 2-5, emb_dim], (c_0 s_0, a_0, c_1, s_1, a_1, ...)
        sequence = torch.stack(seq_list, dim=1).permute(0, 2, 1, 3)
        sequence = sequence.reshape(batch_size, self.seq_repeat * seq_len,
                                    self.embedding_dim)

        if padding_mask is not None:
            # [batch_size, seq_len * self.seq_repeat], stack mask identically to fit the sequence
            padding_mask = torch.stack([padding_mask] * self.seq_repeat,
                                       dim=1).permute(0, 2, 1).reshape(batch_size, -1)

        if self.cost_prefix:
            episode_cost = episode_cost.unsqueeze(-1).unsqueeze(-1)

            episode_cost = episode_cost.to(states.dtype)
            # [batch, 1, emb_dim]
            episode_cost_emb = self.prefix_emb(episode_cost)
            # [batch, 1+seq_len * self.seq_repeat, emb_dim]
            sequence = torch.cat([episode_cost_emb, sequence], dim=1)
            if padding_mask is not None:
                # the prefix token reuses the mask value of the first token
                # [batch_size, 1+ seq_len * self.seq_repeat]
                padding_mask = torch.cat([padding_mask[:, :1], padding_mask], dim=1)
        if self.use_prompt and self.prompt_prefix:
            prompt = prompt.to(states.dtype)
            prompt_emb = self.prompt_emb(prompt)
            # the prompt token goes in front of the (optional) cost prefix
            sequence = torch.cat([prompt_emb, sequence], dim=1)
            if padding_mask is not None:
                padding_mask = torch.cat([padding_mask[:, :1], padding_mask], dim=1)

        # LayerNorm and Dropout (!!!) as in original implementation,
        # while minGPT & huggingface uses only embedding dropout
        out = self.emb_norm(sequence)
        out = self.emb_drop(out)

        for block in self.blocks:
            out = block(out, padding_mask=padding_mask)

        # [batch_size, seq_len * self.seq_repeat, embedding_dim]
        out = self.out_norm(out)
        # drop the prefix tokens again so the rest can be reshaped per timestep
        if self.cost_prefix and self.use_prompt and self.prompt_prefix:
            # [batch_size, seq_len * seq_repeat, embedding_dim]
            out = out[:, 2:]
        elif self.cost_prefix or (self.use_prompt and self.prompt_prefix):
            out = out[:, 1:]

        # [batch_size, seq_len, self.seq_repeat, embedding_dim]
        out = out.reshape(batch_size, seq_len, self.seq_repeat, self.embedding_dim)
        # [batch_size, self.seq_repeat, seq_len, embedding_dim]
        out = out.permute(0, 2, 1, 3)

        # [batch_size, seq_len, embedding_dim]
        # action token is the last token of a timestep, state token the one before
        action_feature = out[:, self.seq_repeat - 1]
        state_feat = out[:, self.seq_repeat - 2]

        if self.add_cost_feat and self.use_cost:
            state_feat = state_feat + costs_emb.detach()
        if self.mul_cost_feat and self.use_cost:
            state_feat = state_feat * costs_emb.detach()
        if self.cat_cost_feat and self.use_cost:
            # cost_prefix feature, deprecated
            # episode_cost_emb = episode_cost_emb.repeat_interleave(seq_len, dim=1)
            # [batch_size, seq_len, 2 * embedding_dim]
            state_feat = torch.cat([state_feat, costs_emb.detach()], dim=2)

        # get predictions
        want_mu_std = self.stochastic and return_mu_std
        if want_mu_std:
            # predict next action given state, [batch_size, seq_len, action_dim]
            mu, log_std = self._action_mu_std(state_feat)
        else:
            # predict next action given state, [batch_size, seq_len, action_dim]
            action_preds = self.action_head(state_feat)
        # [batch_size, seq_len, 2]
        cost_preds = self.cost_pred_head(
            action_feature)  # predict next cost return given state and action
        cost_preds = F.log_softmax(cost_preds, dim=-1)

        state_preds = self.state_pred_head(
            action_feature)  # predict next state given state and action

        if want_mu_std:
            return mu, log_std, cost_preds, state_preds

        return action_preds, cost_preds, state_preds

class CDT_with_action_AE(nn.Module):
    """CDT wrapper that runs the transformer on autoencoder-encoded actions.

    Actions are passed through ``action_encoder.encode`` before being fed to
    the wrapped model, and the predicted actions (or the predicted ``mu`` and
    ``log_std`` for stochastic models) are passed through
    ``action_encoder.decode`` afterwards.

    Args:
        CDT: the wrapped CDT model instance.
        action_encoder: object exposing ``encode`` and ``decode``.
        action_dim: action dimension stored on the wrapper.
    """

    def __init__(self,
                 CDT,
                 action_encoder,
                 action_dim):

        super().__init__()
        self.cdt = CDT
        self.action_ae = action_encoder
        self.action_dim = action_dim
        # mirror the attributes of the wrapped model that callers read
        self.stochastic = self.cdt.stochastic
        self.max_action = self.cdt.max_action
        self.state_dim = self.cdt.state_dim
        self.episode_len = self.cdt.episode_len
        self.seq_len = self.cdt.seq_len
        # self.log_temperature = self.cdt.log_temperature
        if self.stochastic:
            self.target_entropy = self.cdt.target_entropy

    def temperature(self):
        """Return the wrapped model's temperature, or ``None`` if deterministic."""
        if self.stochastic:
            return self.cdt.log_temperature.exp()
        else:
            return None

    def forward(self,
            states: torch.Tensor,  # [batch_size, seq_len, state_dim]
            actions: torch.Tensor,  # [batch_size, seq_len, action_dim]
            returns_to_go: torch.Tensor,  # [batch_size, seq_len]
            costs_to_go: torch.Tensor,  # [batch_size, seq_len]
            time_steps: torch.Tensor,  # [batch_size, seq_len]
            padding_mask: Optional[torch.Tensor] = None,  # [batch_size, seq_len]
            episode_cost: torch.Tensor = None,  # [batch_size, ]
            ):
        """Encode the actions, run the wrapped model and decode its action output.

        Returns:
            ``(action_preds, cost_preds, state_preds)``. For stochastic models
            ``action_preds`` is a ``Normal`` built from the decoded ``mu`` and
            the exponential of the decoded ``log_std``.
        """
        actions = self.action_ae.encode(actions)
        # padding_mask / episode_cost are passed by keyword: the sixth
        # positional parameter of CDT.forward is `prompt`, not `padding_mask`.
        if self.stochastic:
            mu, log_std, cost_preds, state_preds = self.cdt(
                states, actions, returns_to_go, costs_to_go, time_steps,
                padding_mask=padding_mask,
                episode_cost=episode_cost,
                return_mu_std=True)
            mu_dec = self.action_ae.decode(mu)
            log_std_dec = self.action_ae.decode(log_std)
            std_dec = log_std_dec.exp()
            action_preds = Normal(mu_dec, std_dec)
            return action_preds, cost_preds, state_preds
        # the wrapped model returns three values here; slicing also tolerates
        # a model that appends extra outputs
        action_preds, cost_preds, state_preds = self.cdt(
            states, actions, returns_to_go, costs_to_go, time_steps,
            padding_mask=padding_mask,
            episode_cost=episode_cost)[:3]
        action_preds = self.action_ae.decode(action_preds)
        return action_preds, cost_preds, state_preds

class MTCDT(nn.Module):
    """Multi-task CDT wrapper with one state/action autoencoder per task.

    States and actions are encoded with the autoencoders selected by
    ``task_id`` before being passed to the shared wrapped model, and the
    model's action and state predictions are decoded with the same
    autoencoders.

    Args:
        CDT: the wrapped CDT model instance.
        action_encoder_ls: per-task action autoencoders (``nn.Module``s
            exposing ``encode`` / ``decode``).
        state_encoder_ls: per-task state autoencoders.
    """

    def __init__(self,
                 CDT,
                 action_encoder_ls,
                 state_encoder_ls):

        super().__init__()
        self.cdt = CDT
        self.action_ae_ls = nn.ModuleList(action_encoder_ls)
        self.state_ae_ls = nn.ModuleList(state_encoder_ls)
        # mirror the attributes of the wrapped model that callers read
        self.stochastic = self.cdt.stochastic
        self.max_action = self.cdt.max_action
        self.seq_len = self.cdt.seq_len
        self.use_prompt = self.cdt.use_prompt
        self.prompt_prefix = self.cdt.prompt_prefix
        self.prompt_concat = self.cdt.prompt_concat
        self.prompt_dim = self.cdt.prompt_dim
        # self.log_temperature = self.cdt.log_temperature
        if self.stochastic:
            self.target_entropy = self.cdt.target_entropy

    def temperature(self):
        """Return the wrapped model's temperature, or ``None`` if deterministic."""
        if self.stochastic:
            return self.cdt.log_temperature.exp()
        else:
            return None

    def _decode_state_preds(self, state_preds, task_id, state_org_dim):
        """Decode predicted states, leaving concatenated prompt features untouched."""
        state_ae = self.state_ae_ls[task_id]
        if not self.prompt_concat:
            return state_ae.decode(state_preds)
        # only the leading (encoded-state) features go through the decoder
        org_state_preds = state_ae.decode(state_preds[..., :state_org_dim])
        prompt_state_preds = state_preds[..., state_org_dim:]
        return torch.cat([org_state_preds, prompt_state_preds], dim=-1)

    def forward(self,
            states: torch.Tensor,  # [batch_size, seq_len, state_dim]
            actions: torch.Tensor,  # [batch_size, seq_len, action_dim]
            returns_to_go: torch.Tensor,  # [batch_size, seq_len]
            costs_to_go: torch.Tensor,  # [batch_size, seq_len]
            time_steps: torch.Tensor,  # [batch_size, seq_len]
            task_id: int,
            prompt: torch.Tensor = None,
            padding_mask: Optional[torch.Tensor] = None,  # [batch_size, seq_len]
            episode_cost: torch.Tensor = None,  # [batch_size, ]
            ):
        """Encode inputs for ``task_id``, run the wrapped model, decode outputs.

        Returns:
            ``(action_preds, cost_preds, state_preds)``. For stochastic models
            ``action_preds`` is a ``Normal`` built from the decoded ``mu`` and
            ``log_std``.

        Raises:
            ValueError: if ``prompt_concat`` is enabled but ``prompt`` is ``None``.
        """
        actions = self.action_ae_ls[task_id].encode(actions)
        states = self.state_ae_ls[task_id].encode(states)
        # width of the encoded state, needed to split predictions again later
        state_org_dim = states.shape[-1]
        if self.prompt_concat:
            if prompt is None:
                raise ValueError("`prompt` must be provided when prompt_concat is enabled")
            states = torch.cat([states, prompt], dim=-1)
        if self.stochastic:
            mu, log_std, cost_preds, state_preds = self.cdt(
                states, actions, returns_to_go, costs_to_go, time_steps,
                prompt=prompt,
                padding_mask=padding_mask,
                episode_cost=episode_cost,
                return_mu_std=True)
            mu_dec, log_std_dec = self.action_ae_ls[task_id].decode(mu, log_std)
            # log_std_dec = self.action_ae_ls[task_id].decode(log_std)
            std_dec = log_std_dec.exp()
            action_preds = Normal(mu_dec, std_dec)
            state_preds = self._decode_state_preds(state_preds, task_id, state_org_dim)
            return action_preds, cost_preds, state_preds
        # the wrapped model returns three values here; slicing also tolerates
        # a model that appends extra outputs
        action_preds, cost_preds, state_preds = self.cdt(
            states, actions, returns_to_go, costs_to_go, time_steps,
            prompt=prompt,
            padding_mask=padding_mask,
            episode_cost=episode_cost)[:3]
        action_preds = self.action_ae_ls[task_id].decode(action_preds)
        state_preds = self._decode_state_preds(state_preds, task_id, state_org_dim)
        return action_preds, cost_preds, state_preds


class PromptCDT(nn.Module):
    """Multi-task CDT wrapper that prepends a prompt trajectory to the input.

    The prompt states/actions/returns/costs are concatenated in front of the
    regular sequence along the time axis, the whole sequence is encoded with
    the autoencoders selected by ``task_id`` and passed to the wrapped model,
    and the predictions for the prompt positions are dropped again.

    Args:
        CDT: the wrapped CDT model instance.
        action_encoder_ls: per-task action autoencoders (``nn.Module``s
            exposing ``encode`` / ``decode``).
        state_encoder_ls: per-task state autoencoders.
        device: device stored on the wrapper as ``self.device``.
    """

    def __init__(self,
                 CDT,
                 action_encoder_ls,
                 state_encoder_ls,
                 device="cpu"
                 ):

        super().__init__()
        self.cdt = CDT
        self.action_ae_ls = nn.ModuleList(action_encoder_ls)
        self.state_ae_ls = nn.ModuleList(state_encoder_ls)
        # mirror the attributes of the wrapped model that callers read
        self.stochastic = self.cdt.stochastic
        self.max_action = self.cdt.max_action
        self.seq_len = self.cdt.seq_len
        self.use_prompt = self.cdt.use_prompt
        self.prompt_prefix = self.cdt.prompt_prefix
        self.prompt_concat = self.cdt.prompt_concat
        self.prompt_dim = self.cdt.prompt_dim
        self.device = torch.device(device)
        # self.log_temperature = self.cdt.log_temperature
        if self.stochastic:
            self.target_entropy = self.cdt.target_entropy

    def temperature(self):
        """Return the wrapped model's temperature, or ``None`` if deterministic."""
        if self.stochastic:
            return self.cdt.log_temperature.exp()
        else:
            return None

    def _decode_state_preds(self, state_preds, task_id, state_org_dim):
        """Decode predicted states, leaving trailing extra features untouched."""
        state_ae = self.state_ae_ls[task_id]
        if not self.prompt_concat:
            return state_ae.decode(state_preds)
        # only the leading (encoded-state) features go through the decoder
        org_state_preds = state_ae.decode(state_preds[..., :state_org_dim])
        prompt_state_preds = state_preds[..., state_org_dim:]
        return torch.cat([org_state_preds, prompt_state_preds], dim=-1)

    def forward(self,
            states: torch.Tensor,  # [batch_size, seq_len, state_dim]
            actions: torch.Tensor,  # [batch_size, seq_len, action_dim]
            returns_to_go: torch.Tensor,  # [batch_size, seq_len]
            costs_to_go: torch.Tensor,  # [batch_size, seq_len]
            time_steps: torch.Tensor,  # [batch_size, seq_len]
            task_id: int,
            prompt_states: torch.Tensor,
            prompt_actions: torch.Tensor,
            prompt_returns_to_go: torch.Tensor,
            prompt_costs_to_go: torch.Tensor,
            padding_mask: Optional[torch.Tensor] = None,  # [batch_size, seq_len]
            episode_cost: torch.Tensor = None,  # [batch_size, ]
            ):
        """Run the wrapped model on ``prompt + sequence`` and drop the prompt part.

        Returns:
            ``(action_preds, cost_preds, state_preds)`` restricted to the
            non-prompt timesteps. For stochastic models ``action_preds`` is a
            ``Normal`` built from the decoded ``mu`` and ``log_std``.
        """
        prompt_length = prompt_states.shape[1]
        states = torch.cat([prompt_states, states], dim=1)
        actions = torch.cat([prompt_actions, actions], dim=1)
        returns_to_go = torch.cat([prompt_returns_to_go, returns_to_go], dim=1)
        costs_to_go = torch.cat([prompt_costs_to_go, costs_to_go], dim=1)
        # All prompt positions use timestep 0. The tensor is created on the
        # device of `time_steps` so the concatenation cannot mix devices.
        add_time_steps = torch.zeros(states.shape[0], prompt_length,
                                     dtype=time_steps.dtype,
                                     device=time_steps.device)
        time_steps = torch.cat([add_time_steps, time_steps], dim=1)
        if padding_mask is not None:
            # prompt positions reuse the mask value of the first real timestep
            padding_mask = torch.cat([padding_mask[:,:1].repeat(1, prompt_length), padding_mask], dim=1)

        actions = self.action_ae_ls[task_id].encode(actions)
        states = self.state_ae_ls[task_id].encode(states)
        # width of the encoded state, needed to split predictions again later
        state_org_dim = states.shape[-1]
        if self.stochastic:
            mu, log_std, cost_preds, state_preds = self.cdt(
                states, actions, returns_to_go, costs_to_go, time_steps,
                prompt=None,
                padding_mask=padding_mask,
                episode_cost=episode_cost,
                return_mu_std=True)
            mu_dec, log_std_dec = self.action_ae_ls[task_id].decode(mu, log_std)
            # log_std_dec = self.action_ae_ls[task_id].decode(log_std)
            std_dec = log_std_dec.exp()
            mu_dec = mu_dec[:,prompt_length:,:]
            std_dec = std_dec[:,prompt_length:,:]
            action_preds = Normal(mu_dec, std_dec)
            state_preds = self._decode_state_preds(state_preds, task_id, state_org_dim)
            return action_preds, cost_preds[:,prompt_length:,:], state_preds[:,prompt_length:,:]
        # the wrapped model returns three values here; slicing also tolerates
        # a model that appends extra outputs
        action_preds, cost_preds, state_preds = self.cdt(
            states, actions, returns_to_go, costs_to_go, time_steps,
            prompt=None,
            padding_mask=padding_mask,
            episode_cost=episode_cost)[:3]
        action_preds = self.action_ae_ls[task_id].decode(action_preds)
        state_preds = self._decode_state_preds(state_preds, task_id, state_org_dim)
        return action_preds[:,prompt_length:,:], cost_preds[:,prompt_length:,:], state_preds[:,prompt_length:,:]


class CDTTrainer:
    """
    Constrained Decision Transformer Trainer
    
    Args:
        model (CDT): A CDT model to train.
        env (gym.Env): The OpenAI Gym environment to train the model in.
        logger (WandbLogger or DummyLogger): The logger to use for tracking training progress.
        learning_rate (float): The learning rate for the optimizer.
        weight_decay (float): The weight decay for the optimizer.
        betas (Tuple[float, ...]): The betas for the optimizer.
        clip_grad (float): The clip gradient value.
        lr_warmup_steps (int): The number of warmup steps for the learning rate scheduler.
        reward_scale (float): The scaling factor for the reward signal.
        cost_scale (float): The scaling factor for the constraint cost.
        loss_cost_weight (float): The weight for the cost loss.
        loss_state_weight (float): The weight for the state loss.
        cost_reverse (bool): Whether to reverse the cost.
        no_entropy (bool): Whether to disable the entropy regularization term of the action loss.
        device (str): The device to use for training (e.g. "cpu" or "cuda").

    """

    def __init__(
            self,
            model,
            env: gym.Env,
            logger: WandbLogger = DummyLogger(),
            # training params
            learning_rate: float = 1e-4,
            weight_decay: float = 1e-4,
            betas: Tuple[float, ...] = (0.9, 0.999),
            clip_grad: float = 0.25,
            lr_warmup_steps: int = 10000,
            reward_scale: float = 1.0,
            cost_scale: float = 1.0,
            loss_cost_weight: float = 0.0,
            loss_state_weight: float = 0.0,
            cost_reverse: bool = False,
            no_entropy: bool = False,
            device="cpu",
            rtg_model = None,
            rtg_sample_num: int = 1000,
            rtg_sample_quantile: float = 0.8,
            rtg_sample_quantile_end: float = 0.5,
            rtg_update_every_step: bool = False,
            state_encoder = None,
            action_encoder = None,
            train_action_encoder: bool = True) -> None:
        """Store the configuration and build the optimizers and LR schedule.

        An AdamW optimizer with a linear warm-up schedule is created for
        ``model.parameters()``. For stochastic models a separate Adam
        optimizer is created for the model's ``log_temperature`` tensor.
        ``rtg_model`` (if given) is switched to eval mode; the remaining
        ``rtg_*`` values and the encoders are only stored on the trainer here.
        """
        self.model = model
        self.logger = logger
        self.env = env
        self.clip_grad = clip_grad
        self.reward_scale = reward_scale
        self.cost_scale = cost_scale
        self.device = device
        self.cost_weight = loss_cost_weight
        self.state_weight = loss_state_weight
        self.cost_reverse = cost_reverse
        self.no_entropy = no_entropy

        self.state_encoder = state_encoder
        self.action_encoder = action_encoder
        self.train_action_encoder = train_action_encoder
        # self.action_encoder_logger = action_encoder_logger

        self.rtg_model = rtg_model
        if self.rtg_model is not None:
            self.rtg_model.eval()
        self.rtg_sample_num = rtg_sample_num
        self.rtg_sample_quantile = rtg_sample_quantile
        self.rtg_sample_quantile_start = rtg_sample_quantile
        self.rtg_sample_quantile_end = rtg_sample_quantile_end
        self.rtg_update_every_step = rtg_update_every_step
        self.optim = torch.optim.AdamW(
            self.model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
            betas=betas,
        )

        # Linear warm-up to the base learning rate. Clamping to 1 avoids a
        # division by zero for lr_warmup_steps=0 and leaves every positive
        # setting unchanged.
        warmup_steps = max(lr_warmup_steps, 1)
        self.scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.optim,
            lambda steps: min((steps + 1) / warmup_steps, 1),
        )
        self.stochastic = self.model.stochastic
        if self.stochastic:
            # Wrapper models keep the temperature on their inner `cdt`, a bare
            # CDT holds it itself. Resolving it from the model structure works
            # regardless of how the encoder arguments are set.
            temperature_owner = getattr(self.model, "cdt", self.model)
            self.log_temperature_optimizer = torch.optim.Adam(
                [temperature_owner.log_temperature],
                lr=1e-4,
                betas=[0.9, 0.999],
            )
        self.max_action = self.model.max_action

        self.beta_dist = Beta(torch.tensor(2, dtype=torch.float, device=self.device),
                              torch.tensor(5, dtype=torch.float, device=self.device))

    def train_one_step(self, states, actions, returns, costs_return, time_steps, mask,
                       episode_cost, costs):
        """Run one optimisation step on a batch and log the losses.

        The total loss is the action loss plus the weighted cost
        classification loss and the weighted next-state regression loss.
        For stochastic models the temperature is updated afterwards with its
        own optimizer.

        Args:
            states, actions, returns, costs_return, time_steps: batch tensors
                forwarded to the model.
            mask: non-zero entries mark valid timesteps; its negation is
                passed to the model as padding mask.
            episode_cost: forwarded to the model.
            costs: per-step cost labels used as class indices for the cost head.
        """
        # True value indicates that the corresponding key value will be ignored
        padding_mask = ~mask.to(torch.bool)

        action_preds, cost_preds, state_preds = self.model(
            states=states,
            actions=actions,
            returns_to_go=returns,
            costs_to_go=costs_return,
            time_steps=time_steps,
            padding_mask=padding_mask,
            episode_cost=episode_cost,
        )

        flat_mask = mask.flatten()

        if self.stochastic:
            valid = mask > 0
            log_likelihood = action_preds.log_prob(actions)[valid].mean()
            entropy = action_preds.entropy()[valid].mean()
            if self.no_entropy:
                # entropy bonus switched off
                ent_coef, ent_coef_logged = 0.0, 0.0
            else:
                ent_coef = self.model.temperature().detach()
                ent_coef_logged = ent_coef.item()
            act_loss = -(log_likelihood + ent_coef * entropy)
            self.logger.store(tab="train",
                              nll=-log_likelihood.item(),
                              ent=entropy.item(),
                              ent_reg=ent_coef_logged)
        else:
            # [batch_size, seq_len, action_dim] * [batch_size, seq_len, 1]
            per_elem = F.mse_loss(action_preds, actions.detach(), reduction="none")
            act_loss = (per_elem * mask.unsqueeze(-1)).mean()

        # cost_preds: [batch_size * seq_len, 2], costs: [batch_size * seq_len]
        cost_preds = cost_preds.reshape(-1, 2)
        costs = costs.flatten().long().detach()
        per_step_cost_loss = F.nll_loss(cost_preds, costs, reduction="none")
        cost_loss = (per_step_cost_loss * flat_mask).mean()
        # accuracy of the predicted cost class over the valid timesteps
        pred = cost_preds.detach().max(dim=1)[1]
        hits = (pred == costs.view_as(pred)) * flat_mask
        acc = hits.sum() / mask.sum()

        # the prediction at step t is compared with the state at step t + 1
        # [batch_size, seq_len, state_dim]
        per_elem_state = F.mse_loss(state_preds[:, :-1],
                                    states[:, 1:].detach(),
                                    reduction="none")
        state_loss = (per_elem_state * mask[:, :-1].unsqueeze(-1)).mean()

        loss = act_loss + self.cost_weight * cost_loss + self.state_weight * state_loss

        self.optim.zero_grad()
        loss.backward()
        if self.clip_grad is not None:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_grad)
        self.optim.step()

        if self.stochastic:
            # temperature update; the entropy gap is treated as a constant
            self.log_temperature_optimizer.zero_grad()
            entropy_gap = (entropy - self.model.target_entropy).detach()
            temperature_loss = self.model.temperature() * entropy_gap
            temperature_loss.backward()
            self.log_temperature_optimizer.step()

        self.scheduler.step()
        self.logger.store(
            tab="train",
            all_loss=loss.item(),
            act_loss=act_loss.item(),
            cost_loss=cost_loss.item(),
            cost_acc=acc.item(),
            state_loss=state_loss.item(),
            train_lr=self.scheduler.get_last_lr()[0],
        )

    def evaluate(self, num_rollouts, target_return, target_cost, keep_ctg_positive=False, conservative=False, need_rescale=False, deterministic=False, return_target_return=False):
        """
        Evaluates the performance of the model on a number of episodes.

        The model is put in eval mode, ``self.rollout`` is called
        ``num_rollouts`` times and the model is put back in train mode.
        ``conservative`` only takes effect when the trainer has an
        ``rtg_model``; in that case the rollout additionally receives
        ``conservative=True`` and the ``rtg_model``.

        Returns:
            Mean episode return divided by ``reward_scale``, mean episode cost
            divided by ``cost_scale`` and mean episode length. With
            ``return_target_return`` a fourth value is appended: the mean of
            the per-episode ``target return / length`` ratio, divided by
            ``reward_scale``.
        """
        self.model.eval()

        # keyword arguments shared by every rollout of this evaluation
        rollout_kwargs = {
            "keep_ctg_positive": keep_ctg_positive,
            "need_rescale": need_rescale,
            "deterministic": deterministic,
        }
        if conservative and self.rtg_model is not None:
            rollout_kwargs["conservative"] = True
            rollout_kwargs["rtg_model"] = self.rtg_model
        if return_target_return:
            rollout_kwargs["return_target_return"] = return_target_return

        episode_rets, episode_costs, episode_lens = [], [], []
        episode_target_returns = []
        for _ in trange(num_rollouts, desc="Evaluating...", leave=False):
            outcome = self.rollout(self.model, self.env, target_return, target_cost,
                                   **rollout_kwargs)
            if return_target_return:
                epi_ret, epi_len, epi_cost, epi_target_return = outcome
            else:
                epi_ret, epi_len, epi_cost = outcome
            episode_rets.append(epi_ret)
            episode_lens.append(epi_len)
            episode_costs.append(epi_cost)
            if return_target_return:
                episode_target_returns.append(epi_target_return/epi_len)
        self.model.train()

        mean_ret = np.mean(episode_rets) / self.reward_scale
        mean_cost = np.mean(episode_costs) / self.cost_scale
        mean_len = np.mean(episode_lens)
        if return_target_return:
            mean_target = np.mean(episode_target_returns) / self.reward_scale
            return mean_ret, mean_cost, mean_len, mean_target
        return mean_ret, mean_cost, mean_len
    

    @torch.no_grad()
    def rollout(
        self,
        model: CDT,
        env: gym.Env,
        target_return: float,
        target_cost: float,
        keep_ctg_positive: bool = False,
        conservative: bool = False,
        rtg_model = None,
        need_rescale: bool = False,
        deterministic: bool = False,
        return_rollout: bool = False,
        return_target_return: bool = False
    ) -> Tuple:
        """
        Evaluates the performance of the model on a single episode.

        Args:
            model: policy queried at every step; ``episode_len``, ``state_dim``,
                ``action_dim`` and ``seq_len`` are read from it.
            env: environment following the gymnasium ``reset``/``step`` API whose
                ``info`` dict carries a ``"cost"`` entry.
            target_return: initial return-to-go.
            target_cost: initial cost-to-go (also used as the episode cost token).
            keep_ctg_positive: clamp the cost-to-go at zero after each step.
            conservative: together with ``rtg_model``, cap the return-to-go with
                the prediction of ``rtg_model``.
            rtg_model: callable ``rtg_model(cost_to_go, states=state)`` returning
                ``(prediction, distribution, std)``.
            need_rescale: multiply the ``rtg_model`` based target by 100.
            deterministic: use the point prediction of ``rtg_model`` instead of a
                quantile of samples from its distribution.
            return_rollout: return the collected transitions instead of the
                episode statistics.
            return_target_return: additionally return the summed return-to-go
                targets of the episode.

        Returns:
            * ``return_rollout``: ``(obs, actions, next_obs, rewards, costs)`` lists;
            * ``return_target_return``: ``(episode_ret, episode_len, episode_cost,
              episode_target_return)``;
            * otherwise ``(episode_ret, episode_len, episode_cost)``.
        """
        horizon = model.episode_len
        states = torch.zeros(1, horizon + 1, model.state_dim,
                             dtype=torch.float, device=self.device)
        actions = torch.zeros(1, horizon, model.action_dim,
                              dtype=torch.float, device=self.device)
        returns = torch.zeros(1, horizon + 1, dtype=torch.float, device=self.device)
        costs = torch.zeros(1, horizon + 1, dtype=torch.float, device=self.device)
        time_steps = torch.arange(horizon, dtype=torch.long,
                                  device=self.device).view(1, -1)

        obs_ls, action_ls, next_obs_ls, rewards_ls, costs_ls = [], [], [], [], []
        use_rtg_model = conservative and rtg_model is not None

        # The quantile is annealed while the episode consumes its cost budget, so
        # every episode has to start from the configured value again; otherwise
        # the result of one rollout depends on the violations of earlier ones.
        self.rtg_sample_quantile = self.rtg_sample_quantile_start
        quantile_span = self.rtg_sample_quantile_start - self.rtg_sample_quantile_end
        # With a zero budget the first violation drops the quantile straight to
        # its floor (an infinite step) instead of raising ZeroDivisionError.
        if target_cost != 0:
            self.rtg_sample_quantile_step = quantile_span / target_cost
        else:
            self.rtg_sample_quantile_step = float("inf")

        obs, info = env.reset()

        if self.state_encoder is not None:
            states[:, 0] = self.state_encoder.encode(
                torch.as_tensor(obs, dtype=torch.float)).to(self.device)
        else:
            states[:, 0] = torch.as_tensor(obs, device=self.device)
        costs[:, 0] = torch.as_tensor(target_cost, device=self.device)

        if use_rtg_model:
            predicted = self._rollout_predict_rtg(rtg_model, costs[:, 0],
                                                  states[:, 0], deterministic)
            target_return = min(predicted, target_return)
            # NOTE(review): here the cap is applied before the x100 rescale,
            # whereas inside the loop the prediction is rescaled before it is
            # compared -- kept as is, confirm which order is intended.
            if need_rescale:
                target_return = target_return * 100
        returns[:, 0] = torch.as_tensor(target_return, device=self.device)

        epi_cost = torch.tensor(np.array([target_cost]),
                                dtype=torch.float,
                                device=self.device)

        # cannot step higher than model episode len, as timestep embeddings will crash
        episode_ret, episode_cost, episode_len = 0.0, 0.0, 0
        episode_target_return = target_return
        for step in range(horizon):
            # first select history up to step, then select last seq_len states,
            # step + 1 as : operator is not inclusive, last action is dummy with zeros
            # (as model will predict last, actual last values are not important)
            s = states[:, :step + 1][:, -model.seq_len:]
            a = actions[:, :step + 1][:, -model.seq_len:]
            r = returns[:, :step + 1][:, -model.seq_len:]
            c = costs[:, :step + 1][:, -model.seq_len:]
            t = time_steps[:, :step + 1][:, -model.seq_len:]

            acts, _, _ = model(s, a, r, c, t, None, None, epi_cost, False)
            if self.stochastic:
                acts = acts.mean
            acts = acts.clamp(-self.max_action, self.max_action)
            act = acts[0, -1]

            # With a frozen action encoder the environment receives the decoded
            # action while the history keeps its re-encoded counterpart.
            if self.action_encoder is not None and not self.train_action_encoder:
                decoded = self.action_encoder.decode(act.cpu())
                act = self.action_encoder.encode(decoded).cpu().numpy()
                real_act = decoded.cpu().numpy()
            else:
                act = act.cpu().numpy()
                real_act = act

            obs_next, reward, terminated, truncated, info = env.step(real_act)
            if self.cost_reverse:
                cost = (1.0 - info["cost"]) * self.cost_scale
            else:
                cost = info["cost"] * self.cost_scale

            if return_rollout:
                obs_ls.append(obs)
                action_ls.append(real_act)
                next_obs_ls.append(obs_next)
                rewards_ls.append(reward)
                costs_ls.append(cost)

            if cost != 0:
                # every costly step makes the sampled target more pessimistic
                self.rtg_sample_quantile = max(
                    self.rtg_sample_quantile - self.rtg_sample_quantile_step,
                    self.rtg_sample_quantile_end)

            # at step t, we predict a_t, get s_{t + 1}, r_{t + 1}
            actions[:, step] = torch.as_tensor(act)
            if self.state_encoder is not None:
                states[:, step + 1] = self.state_encoder.encode(
                    torch.as_tensor(obs_next, dtype=torch.float)).to(self.device)
            else:
                states[:, step + 1] = torch.as_tensor(obs_next)

            next_ctg = costs[:, step] - cost
            if keep_ctg_positive:
                next_ctg = torch.clamp(next_ctg, min=0.0)
            costs[:, step + 1] = next_ctg

            if use_rtg_model and (cost != 0 or self.rtg_update_every_step):
                predicted = self._rollout_predict_rtg(
                    rtg_model, costs[:, step + 1], states[:, step + 1], deterministic)
                if need_rescale:
                    predicted = predicted * 100
                org_next_rtg = returns[:, step].squeeze().item() - reward
                # never ask for more than the model deems reachable, nor for < 0
                true_next_rtg = max(min(org_next_rtg, predicted), 0.0)
                returns[:, step + 1] = torch.as_tensor(true_next_rtg)
            else:
                returns[:, step + 1] = returns[:, step] - reward
            episode_target_return += returns[0, step + 1].item()

            obs = obs_next

            episode_ret += reward
            episode_len += 1
            # the raw environment cost is reported, not the scaled/reversed one
            episode_cost += info["cost"]

            if terminated or truncated:
                break

        if return_rollout:
            return obs_ls, action_ls, next_obs_ls, rewards_ls, costs_ls
        if return_target_return:
            return episode_ret, episode_len, episode_cost, episode_target_return
        return episode_ret, episode_len, episode_cost

    def _rollout_predict_rtg(self, rtg_model, cost_to_go, state, deterministic):
        """Return the return-to-go proposed by ``rtg_model`` as a Python float.

        Uses the point prediction when ``deterministic`` is truthy, otherwise the
        ``self.rtg_sample_quantile`` quantile of ``self.rtg_sample_num`` samples
        drawn from the predicted distribution.
        """
        rtg_preds, rtg_dis, _ = rtg_model(cost_to_go, states=state)
        if deterministic:
            return rtg_preds.squeeze().item()
        samples = rtg_dis.sample(torch.Size([self.rtg_sample_num]))
        return torch.quantile(samples, self.rtg_sample_quantile, dim=0).squeeze().item()

    def get_ensemble_action(self, size: int, model, s, a, r, c, t, epi_cost):
        """Average the action predicted for ``size`` copies of the same context.

        Every input is repeated ``size`` times along dim 0, the model is queried
        once on that batch, the predictions are averaged over the batch and the
        action of the last position is returned as a numpy array clamped to
        ``[-max_action, max_action]``.

        Args:
            size: number of copies of the context to evaluate.
            model: policy, called with the same positional layout ``rollout``
                uses.
            s, a, r, c, t: states, actions, returns-to-go, costs-to-go and time
                steps of the context, batch dimension first.
            epi_cost: episode cost token, batch dimension first.
        """
        s, a, r, c, t, epi_cost = (
            torch.repeat_interleave(item, size, 0)
            for item in (s, a, r, c, t, epi_cost)
        )

        # Mirror the call made in ``rollout``: the former 7-argument call put
        # ``epi_cost`` one slot earlier than the live evaluation path does.
        acts, _, _ = model(s, a, r, c, t, None, None, epi_cost, False)
        if self.stochastic:
            acts = acts.mean

        # [size, seq_len, act_dim] -> [1, seq_len, act_dim]
        acts = torch.mean(acts, dim=0, keepdim=True)
        acts = acts.clamp(-self.max_action, self.max_action)
        return acts[0, -1].cpu().numpy()

    def collect_random_rollouts(self, model, env):
        """Run one episode with uniformly sampled actions and record it.

        Actions come from ``env.action_space.sample()``; the episode stops on
        termination/truncation or after ``model.episode_len`` steps.

        Args:
            model: only ``model.episode_len`` is read (step cap).
            env: environment following the gymnasium ``reset``/``step`` API whose
                ``info`` dict carries a ``"cost"`` entry.

        Returns:
            ``(obs, actions, next_obs, rewards, costs)`` as parallel lists with
            one entry per step. Costs are multiplied by ``self.cost_scale`` and
            flipped to ``1 - cost`` first when ``self.cost_reverse`` is set.
        """
        obs_ls, action_ls, next_obs_ls, rewards_ls, costs_ls = [], [], [], [], []
        obs, _ = env.reset()
        for _ in range(model.episode_len):
            act = env.action_space.sample()
            obs_next, reward, terminated, truncated, info = env.step(act)
            raw_cost = info["cost"]
            if self.cost_reverse:
                raw_cost = 1.0 - raw_cost
            obs_ls.append(obs)
            action_ls.append(act)
            next_obs_ls.append(obs_next)
            rewards_ls.append(reward)
            costs_ls.append(raw_cost * self.cost_scale)
            obs = obs_next
            if terminated or truncated:
                break
        return obs_ls, action_ls, next_obs_ls, rewards_ls, costs_ls


class MTCDTTrainer:
    """
    Multi-task Constrained Decision Transformer Trainer

    Args:
        model: multi-task CDT model to optimise. The trainer reads
            ``model.stochastic``, ``model.max_action`` and, for stochastic
            policies, ``model.cdt.log_temperature``, ``model.temperature()`` and
            ``model.target_entropy``.
        env_ls: evaluation environments, indexed with the ``task_id`` given to
            ``evaluate``.
        logger: logger receiving the training statistics.
        learning_rate, weight_decay, betas: AdamW settings.
        clip_grad: max gradient norm, ``None`` disables clipping.
        lr_warmup_steps: number of linear learning-rate warm-up steps.
        reward_scale, cost_scale: scales applied to rewards / costs.
        loss_cost_weight, loss_state_weight: weights of the auxiliary cost and
            state prediction losses.
        cost_reverse: use ``1 - cost`` as cost signal during rollouts.
        no_entropy: drop the entropy bonus from the stochastic action loss.
        device: device on which rollout buffers are allocated.
        rtg_model: optional return-to-go predictor used by conservative
            rollouts; it is put into eval mode.
        rtg_sample_num: samples drawn from the predictor's distribution.
        rtg_sample_quantile: quantile of those samples used at episode start.
        rtg_sample_quantile_end: floor the quantile is annealed towards.
        rtg_update_every_step: query ``rtg_model`` after every step rather than
            only after steps with non-zero cost.
    """

    def __init__(
            self,
            model,
            env_ls,
            logger: WandbLogger = DummyLogger(),
            # training params
            learning_rate: float = 1e-4,
            weight_decay: float = 1e-4,
            betas: Tuple[float, ...] = (0.9, 0.999),
            clip_grad: float = 0.25,
            lr_warmup_steps: int = 10000,
            reward_scale: float = 1.0,
            cost_scale: float = 1.0,
            loss_cost_weight: float = 0.0,
            loss_state_weight: float = 0.0,
            cost_reverse: bool = False,
            no_entropy: bool = False,
            device="cpu",
            rtg_model = None,
            rtg_sample_num: int = 1000,
            rtg_sample_quantile: float = 0.8,
            rtg_sample_quantile_end: float = 0.5,
            rtg_update_every_step: bool = True) -> None:
        self.model = model
        self.logger = logger
        self.env_ls = env_ls
        self.clip_grad = clip_grad
        self.reward_scale = reward_scale
        self.cost_scale = cost_scale
        self.device = device
        self.cost_weight = loss_cost_weight
        self.state_weight = loss_state_weight
        self.cost_reverse = cost_reverse
        self.no_entropy = no_entropy

        self.rtg_model = rtg_model
        if self.rtg_model is not None:
            self.rtg_model.eval()
        self.rtg_sample_num = rtg_sample_num
        # ``rtg_sample_quantile`` is the live, annealed value; ``_start`` keeps
        # the configured one so every rollout can start from it again.
        self.rtg_sample_quantile = rtg_sample_quantile
        self.rtg_sample_quantile_start = rtg_sample_quantile
        self.rtg_sample_quantile_end = rtg_sample_quantile_end
        self.rtg_update_every_step = rtg_update_every_step
        self.optim = torch.optim.AdamW(
            self.model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
            betas=betas,
        )

        # linear warm-up to the base learning rate, constant afterwards
        self.scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.optim,
            lambda steps: min((steps + 1) / lr_warmup_steps, 1),
        )
        self.stochastic = self.model.stochastic
        if self.stochastic:
            # the temperature has its own optimizer and its own loss
            self.log_temperature_optimizer = torch.optim.Adam(
                [self.model.cdt.log_temperature],
                lr=1e-4,
                betas=[0.9, 0.999],
            )
        self.max_action = self.model.max_action

        self.beta_dist = Beta(torch.tensor(2, dtype=torch.float, device=self.device),
                              torch.tensor(5, dtype=torch.float, device=self.device))

    # currently without prompt
    def train_one_step(self, states, actions, returns, costs_return, time_steps, mask,
                       episode_cost, costs, task_name, task_id, env_id, prompt=None):
        """Run one optimisation step on a batch and log the losses.

        Args:
            states, actions, returns, costs_return, time_steps: model inputs
                (``returns`` / ``costs_return`` are the return- and cost-to-go).
            mask: non-zero where a position holds real data; its negation is
                handed to the model as padding mask.
            episode_cost: episode cost input of the model.
            costs: per-step cost labels for the cost head (cast to class indices).
            task_name: suffix of the logging tab (``"train/" + task_name``).
            task_id: index into ``model.target_entropy`` (stochastic policies).
            env_id: value passed to the model as ``task_id``.
            prompt: optional prompt; required when ``model.prompt_concat`` is set.
        """
        # True value indicates that the corresponding key value will be ignored
        padding_mask = ~mask.to(torch.bool)
        log_tab = "train/" + task_name

        action_preds, cost_preds, state_preds = self.model(
            states=states,
            actions=actions,
            returns_to_go=returns,
            costs_to_go=costs_return,
            time_steps=time_steps,
            task_id=env_id,
            prompt=prompt,
            padding_mask=padding_mask,
            episode_cost=episode_cost
        )

        if self.stochastic:
            valid = mask > 0
            log_likelihood = action_preds.log_prob(actions)[valid].mean()
            entropy = action_preds.entropy()[valid].mean()
            entropy_reg = self.model.temperature().detach()
            entropy_reg_item = entropy_reg.item()
            if self.no_entropy:
                entropy_reg = 0.0
                entropy_reg_item = 0.0
            act_loss = -(log_likelihood + entropy_reg * entropy)
            self.logger.store(tab=log_tab,
                              nll=-log_likelihood.item(),
                              ent=entropy.item(),
                              ent_reg=entropy_reg_item)
        else:
            act_loss = F.mse_loss(action_preds, actions.detach(), reduction="none")
            # [batch_size, seq_len, action_dim] * [batch_size, seq_len, 1]
            act_loss = (act_loss * mask.unsqueeze(-1)).mean()

        # cost_preds: [batch_size * seq_len, 2], costs: [batch_size * seq_len]
        flat_mask = mask.flatten()
        cost_preds = cost_preds.reshape(-1, 2)
        costs = costs.flatten().long().detach()
        cost_loss = F.nll_loss(cost_preds, costs, reduction="none")
        cost_loss = (cost_loss * flat_mask).mean()
        # compute the accuracy, 0 value, 1 indice, [batch_size, seq_len]
        pred = cost_preds.data.max(dim=1)[1]
        correct = (pred.eq(costs.data.view_as(pred)) * flat_mask).sum()
        acc = correct / mask.sum()

        # [batch_size, seq_len, state_dim]
        if self.model.prompt_concat:
            # the state head then predicts the next state together with the prompt
            next_state_target = torch.cat([states[:, 1:], prompt[:, 1:]], dim=-1)
        else:
            next_state_target = states[:, 1:]
        state_loss = F.mse_loss(state_preds[:, :-1],
                                next_state_target.detach(),
                                reduction="none")
        state_loss = (state_loss * mask[:, :-1].unsqueeze(-1)).mean()

        loss = act_loss + self.cost_weight * cost_loss + self.state_weight * state_loss

        self.optim.zero_grad()
        loss.backward()
        if self.clip_grad is not None:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_grad)
        self.optim.step()

        if self.stochastic:
            # the entropy gap is detached: only the temperature gets a gradient
            self.log_temperature_optimizer.zero_grad()
            temperature_loss = (self.model.temperature() *
                                (entropy - self.model.target_entropy[task_id]).detach())
            temperature_loss.backward()
            self.log_temperature_optimizer.step()

        self.scheduler.step()
        self.logger.store(
            tab=log_tab,
            all_loss=loss.item(),
            act_loss=act_loss.item(),
            cost_loss=cost_loss.item(),
            cost_acc=acc.item(),
            state_loss=state_loss.item(),
            train_lr=self.scheduler.get_last_lr()[0],
        )

    def evaluate(self, num_rollouts, target_return, target_cost, task_id, env_id, episode_len, state_dim, action_dim, keep_ctg_positive=False, conservative=False, prompt=None, need_rescale=False, deterministic=False):
        """
        Evaluates the performance of the model on a number of episodes.

        The model is switched to eval mode for the rollouts and is always put
        back into train mode afterwards, even when a rollout raises.

        Args:
            num_rollouts: number of episodes to run.
            target_return, target_cost: initial return- / cost-to-go.
            task_id: index of the environment in ``self.env_ls``.
            env_id: task identifier handed to the model and ``rtg_model``.
            episode_len, state_dim, action_dim: sizes of the rollout buffers.
            keep_ctg_positive, prompt, need_rescale, deterministic: forwarded to
                ``rollout``.
            conservative: when truthy and ``self.rtg_model`` is set, rollouts
                run with ``conservative=True`` and ``rtg_model=self.rtg_model``.

        Returns:
            ``(mean_return / reward_scale, mean_cost / cost_scale, mean_len)``.
        """
        rollout_kwargs = dict(
            keep_ctg_positive=keep_ctg_positive,
            prompt=prompt,
            need_rescale=need_rescale,
            deterministic=deterministic,
        )
        if conservative and self.rtg_model is not None:
            rollout_kwargs.update(conservative=True, rtg_model=self.rtg_model)

        episode_rets, episode_costs, episode_lens = [], [], []
        self.model.eval()
        try:
            for _ in trange(num_rollouts, desc="Evaluating...", leave=False):
                epi_ret, epi_len, epi_cost = self.rollout(
                    self.model, self.env_ls[task_id], target_return, target_cost,
                    env_id, episode_len, state_dim, action_dim, **rollout_kwargs)
                episode_rets.append(epi_ret)
                episode_lens.append(epi_len)
                episode_costs.append(epi_cost)
        finally:
            # never leave the model stuck in eval mode for the training loop
            self.model.train()
        # NOTE(review): rollout accumulates the raw ``info["cost"]``, so dividing
        # by cost_scale looks correct only when cost_scale == 1 -- confirm.
        return np.mean(episode_rets) / self.reward_scale, np.mean(
            episode_costs) / self.cost_scale, np.mean(episode_lens)

    @torch.no_grad()
    def rollout(
        self,
        model: CDT,
        env: gym.Env,
        target_return: float,
        target_cost: float,
        task_id,
        model_episode_len,
        state_dim,
        action_dim,
        keep_ctg_positive: bool = False,
        conservative: bool = False,
        rtg_model = None,
        prompt = None,
        need_rescale: bool = False,
        deterministic: bool = False
    ) -> Tuple[float, int, float]:
        """
        Evaluates the performance of the model on a single episode.

        Args:
            model: policy queried at every step; ``seq_len``, ``prompt_dim``,
                ``use_prompt``, ``prompt_prefix`` and ``prompt_concat`` are read.
            env: environment following the gymnasium ``reset``/``step`` API whose
                ``info`` dict carries a ``"cost"`` entry.
            target_return, target_cost: initial return- / cost-to-go.
            task_id: task identifier passed to ``model`` and ``rtg_model``.
            model_episode_len: maximum number of steps (size of the buffers).
            state_dim, action_dim: feature sizes of the buffers.
            keep_ctg_positive: clamp the cost-to-go at zero after each step.
            conservative: together with ``rtg_model``, cap the return-to-go with
                the prediction of ``rtg_model``.
            rtg_model: callable ``rtg_model(cost_to_go, states=..., prompts=...,
                task_id=...)`` returning ``(prediction, distribution, std)``.
            prompt: task prompt; may be ``None`` only for models that use neither
                prefix nor per-step prompts.
            need_rescale: multiply the ``rtg_model`` based target by 100.
            deterministic: use the point prediction of ``rtg_model`` instead of a
                quantile of samples from its distribution.

        Returns:
            ``(episode_ret, episode_len, episode_cost)``.
        """
        states = torch.zeros(1, model_episode_len + 1, state_dim,
                             dtype=torch.float, device=self.device)
        actions = torch.zeros(1, model_episode_len, action_dim,
                              dtype=torch.float, device=self.device)
        returns = torch.zeros(1, model_episode_len + 1,
                              dtype=torch.float, device=self.device)
        costs = torch.zeros(1, model_episode_len + 1,
                            dtype=torch.float, device=self.device)
        # per-step copy of the prompt for models that do not take it as a prefix
        none_prefix_prompts = torch.zeros(1, model_episode_len + 1, model.prompt_dim,
                                          dtype=torch.float, device=self.device)
        time_steps = torch.arange(model_episode_len, dtype=torch.long,
                                  device=self.device).view(1, -1)

        use_rtg_model = conservative and rtg_model is not None
        org_prompt = prompt

        # The quantile is annealed while the episode consumes its cost budget, so
        # every episode has to start from the configured value again; otherwise
        # the result of one rollout depends on the violations of earlier ones.
        self.rtg_sample_quantile = self.rtg_sample_quantile_start
        quantile_span = self.rtg_sample_quantile_start - self.rtg_sample_quantile_end
        # With a zero budget the first violation drops the quantile straight to
        # its floor (an infinite step) instead of raising ZeroDivisionError.
        if target_cost != 0:
            self.rtg_sample_quantile_step = quantile_span / target_cost
        else:
            self.rtg_sample_quantile_step = float("inf")

        obs, info = env.reset()
        states[:, 0] = torch.as_tensor(obs, device=self.device)
        costs[:, 0] = torch.as_tensor(target_cost, device=self.device)

        if (model.use_prompt and not model.prompt_prefix) or model.prompt_concat:
            none_prefix_prompts[:, 0] = org_prompt.reshape(-1)

        if use_rtg_model:
            predicted = self._rollout_predict_rtg(
                rtg_model, costs[:, 0], states[:, 0], org_prompt, task_id, deterministic)
            target_return = min(predicted, target_return)
            # NOTE(review): here the cap is applied before the x100 rescale,
            # whereas inside the loop the prediction is rescaled before it is
            # compared -- kept as is, confirm which order is intended.
            if need_rescale:
                target_return = target_return * 100
        returns[:, 0] = torch.as_tensor(target_return, device=self.device)

        epi_cost = torch.tensor(np.array([target_cost]),
                                dtype=torch.float,
                                device=self.device)

        # cannot step higher than model episode len, as timestep embeddings will crash
        episode_ret, episode_cost, episode_len = 0.0, 0.0, 0
        for step in range(model_episode_len):
            # first select history up to step, then select last seq_len states,
            # step + 1 as : operator is not inclusive, last action is dummy with zeros
            # (as model will predict last, actual last values are not important)
            s = states[:, :step + 1][:, -model.seq_len:]
            a = actions[:, :step + 1][:, -model.seq_len:]
            r = returns[:, :step + 1][:, -model.seq_len:]
            c = costs[:, :step + 1][:, -model.seq_len:]
            t = time_steps[:, :step + 1][:, -model.seq_len:]
            step_prompts = none_prefix_prompts[:, :step + 1][:, -model.seq_len:]

            if model.prompt_prefix and not model.prompt_concat:
                # a single prompt token in front of the sequence
                model_prompt = org_prompt.reshape(1, 1, -1).expand(s.shape[0], -1, -1)
            else:
                model_prompt = step_prompts
            acts, _, _ = model(s, a, r, c, t, task_id, model_prompt, None, epi_cost)
            if self.stochastic:
                acts = acts.mean
            acts = acts.clamp(-self.max_action, self.max_action)
            act = acts[0, -1].cpu().numpy()

            obs_next, reward, terminated, truncated, info = env.step(act)
            if self.cost_reverse:
                cost = (1.0 - info["cost"]) * self.cost_scale
            else:
                cost = info["cost"] * self.cost_scale
            if cost != 0:
                # every costly step makes the sampled target more pessimistic
                self.rtg_sample_quantile = max(
                    self.rtg_sample_quantile - self.rtg_sample_quantile_step,
                    self.rtg_sample_quantile_end)

            # at step t, we predict a_t, get s_{t + 1}, r_{t + 1}
            actions[:, step] = torch.as_tensor(act)
            states[:, step + 1] = torch.as_tensor(obs_next)
            if org_prompt is not None:
                # prompt-free evaluation keeps the zero rows instead of crashing
                none_prefix_prompts[:, step + 1] = org_prompt.reshape(-1)

            next_ctg = costs[:, step] - cost
            if keep_ctg_positive:
                next_ctg = torch.clamp(next_ctg, min=0.0)
            costs[:, step + 1] = next_ctg

            if use_rtg_model and (cost != 0 or self.rtg_update_every_step):
                predicted = self._rollout_predict_rtg(
                    rtg_model, costs[:, step + 1], states[:, step + 1],
                    org_prompt, task_id, deterministic)
                if need_rescale:
                    predicted = predicted * 100
                org_next_rtg = returns[:, step].squeeze().item() - reward
                # never ask for more than the model deems reachable, nor for < 0
                true_next_rtg = max(min(org_next_rtg, predicted), 0.0)
                returns[:, step + 1] = torch.as_tensor(true_next_rtg)
            else:
                returns[:, step + 1] = returns[:, step] - reward

            obs = obs_next

            episode_ret += reward
            episode_len += 1
            # the raw environment cost is reported, not the scaled/reversed one
            episode_cost += info["cost"]

            if terminated or truncated:
                break

        return episode_ret, episode_len, episode_cost

    def _rollout_predict_rtg(self, rtg_model, cost_to_go, state, prompt, task_id,
                             deterministic):
        """Return the return-to-go proposed by ``rtg_model`` as a Python float.

        Uses the point prediction when ``deterministic`` is truthy, otherwise the
        ``self.rtg_sample_quantile`` quantile of ``self.rtg_sample_num`` samples
        drawn from the predicted distribution.
        """
        rtg_preds, rtg_dis, _ = rtg_model(
            cost_to_go,
            states=state,
            prompts=prompt,
            task_id=task_id,
        )
        if deterministic:
            return rtg_preds.squeeze().item()
        samples = rtg_dis.sample(torch.Size([self.rtg_sample_num]))
        return torch.quantile(samples, self.rtg_sample_quantile, dim=0).squeeze().item()

    def get_ensemble_action(self, size: int, model, s, a, r, c, t, epi_cost,
                            task_id=None, prompt=None):
        """Average the action predicted for ``size`` copies of the same context.

        Every tensor input is repeated ``size`` times along dim 0, the model is
        queried once on that batch, the predictions are averaged over the batch
        and the action of the last position is returned as a numpy array clamped
        to ``[-max_action, max_action]``.

        Args:
            size: number of copies of the context to evaluate.
            model: policy, called with the same positional layout ``rollout``
                uses.
            s, a, r, c, t: states, actions, returns-to-go, costs-to-go and time
                steps of the context, batch dimension first.
            epi_cost: episode cost token, batch dimension first.
            task_id: task identifier forwarded to the model unchanged.
            prompt: optional prompt, assumed batch-first like the other inputs
                (TODO confirm); repeated along dim 0 when given.
        """
        s, a, r, c, t, epi_cost = (
            torch.repeat_interleave(item, size, 0)
            for item in (s, a, r, c, t, epi_cost)
        )
        if prompt is not None:
            prompt = torch.repeat_interleave(prompt, size, 0)

        # Mirror the call made in ``rollout``; the former 7-argument call put
        # ``epi_cost`` into the prompt slot of the multi-task model.
        acts, _, _ = model(s, a, r, c, t, task_id, prompt, None, epi_cost)
        if self.stochastic:
            acts = acts.mean

        # [size, seq_len, act_dim] -> [1, seq_len, act_dim]
        acts = torch.mean(acts, dim=0, keepdim=True)
        acts = acts.clamp(-self.max_action, self.max_action)
        return acts[0, -1].cpu().numpy()

    def collect_random_rollouts(self, num_rollouts, task_id=0, episode_len=None):
        """Return the mean undiscounted return of uniformly random episodes.

        Args:
            num_rollouts: number of episodes to run.
            task_id: index of the environment in ``self.env_ls``. The trainer
                keeps no single ``self.env``, which the previous code relied on.
            episode_len: step cap per episode; defaults to
                ``self.model.episode_len``.
        """
        env = self.env_ls[task_id]
        if episode_len is None:
            episode_len = self.model.episode_len

        episode_rets = []
        for _ in range(num_rollouts):
            env.reset()
            episode_ret = 0.0
            for _ in range(episode_len):
                act = env.action_space.sample()
                _, reward, terminated, truncated, _ = env.step(act)
                episode_ret += reward
                if terminated or truncated:
                    break
            episode_rets.append(episode_ret)
        return np.mean(episode_rets)
    
class PromptCDTTrainer:
    """
    Prompt Constrained Decision Transformer Trainer

    """

    def __init__(
            self,
            model,
            env_ls,
            logger: WandbLogger = DummyLogger(),
            # training params
            learning_rate: float = 1e-4,
            weight_decay: float = 1e-4,
            betas: Tuple[float, ...] = (0.9, 0.999),
            clip_grad: float = 0.25,
            lr_warmup_steps: int = 10000,
            reward_scale: float = 1.0,
            cost_scale: float = 1.0,
            loss_cost_weight: float = 0.0,
            loss_state_weight: float = 0.0,
            cost_reverse: bool = False,
            no_entropy: bool = False,
            device="cpu",
            rtg_model = None,
            rtg_sample_num: int = 1000,
            rtg_sample_quantile: float = 0.8,
            rtg_sample_quantile_end: float = 0.5,
            rtg_update_every_step: bool = True) -> None:
        """
        Store the training/evaluation settings and build the optimizers.

        Args:
            model: network to optimize. Must expose ``parameters()``,
                ``stochastic`` and ``max_action``; when ``stochastic`` is true
                it must also expose ``cdt.log_temperature``.
            env_ls: environments, indexed by ``task_id`` during evaluation.
            logger: metric sink; ``store(...)`` is called on it while training.
            learning_rate (float): AdamW learning rate.
            weight_decay (float): AdamW weight decay.
            betas (Tuple[float, ...]): AdamW betas.
            clip_grad (float): max gradient norm; ``None`` disables clipping.
            lr_warmup_steps (int): number of scheduler steps over which the
                learning rate ramps up linearly. Values <= 0 disable warmup.
            reward_scale (float): divisor applied to mean returns in
                ``evaluate``.
            cost_scale (float): multiplier applied to per-step costs in
                ``rollout`` and divisor applied to mean costs in ``evaluate``.
            loss_cost_weight (float): weight of the cost prediction loss.
            loss_state_weight (float): weight of the state prediction loss.
            cost_reverse (bool): if True, ``rollout`` uses ``1 - cost``.
            no_entropy (bool): if True, the entropy bonus is zeroed in the
                stochastic action loss.
            device: device on which rollout buffers are allocated.
            rtg_model: optional return-to-go model; switched to eval mode.
            rtg_sample_num (int): stored sample count for ``rtg_model``.
            rtg_sample_quantile (float): initial sampling quantile.
            rtg_sample_quantile_end (float): lower bound of the quantile.
            rtg_update_every_step (bool): stored flag for the RTG update rule.
        """
        self.model = model
        self.logger = logger
        self.env_ls = env_ls
        self.device = device

        # Loss / scaling configuration.
        self.clip_grad = clip_grad
        self.reward_scale = reward_scale
        self.cost_scale = cost_scale
        self.cost_weight = loss_cost_weight
        self.state_weight = loss_state_weight
        self.cost_reverse = cost_reverse
        self.no_entropy = no_entropy

        # Optional return-to-go model and its sampling schedule.
        self.rtg_model = rtg_model
        if self.rtg_model is not None:
            self.rtg_model.eval()
        self.rtg_sample_num = rtg_sample_num
        self.rtg_sample_quantile = rtg_sample_quantile
        self.rtg_sample_quantile_start = rtg_sample_quantile
        self.rtg_sample_quantile_end = rtg_sample_quantile_end
        self.rtg_update_every_step = rtg_update_every_step

        self.optim = torch.optim.AdamW(
            self.model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
            betas=betas,
        )

        # LambdaLR evaluates the lambda immediately, so a zero warmup length
        # would raise ZeroDivisionError (and a negative one would yield a
        # negative LR factor). Non-positive values mean "no warmup".
        warmup_span = lr_warmup_steps if lr_warmup_steps > 0 else 1
        self.scheduler = torch.optim.lr_scheduler.LambdaLR(
            self.optim,
            lambda steps: min((steps + 1) / warmup_span, 1),
        )

        self.stochastic = self.model.stochastic
        if self.stochastic:
            # The entropy temperature is tuned by its own optimizer.
            self.log_temperature_optimizer = torch.optim.Adam(
                [self.model.cdt.log_temperature],
                lr=1e-4,
                betas=(0.9, 0.999),
            )
        self.max_action = self.model.max_action

        # Beta(2, 5) distribution kept on the trainer's device.
        # NOTE(review): not used by the methods visible here — confirm it is
        # still needed.
        self.beta_dist = Beta(torch.tensor(2, dtype=torch.float, device=self.device),
                              torch.tensor(5, dtype=torch.float, device=self.device))

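    # NOTE(review): the prompt_* tensors are forwarded to the model below, so the older "currently without prompt" remark looks stale; confirm.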
    def train_one_step(self, states, actions, returns, costs_return, prompt_states, prompt_actions, prompt_returns, prompt_costs_return, time_steps, mask,
                       episode_cost, costs, task_name, task_id, env_id):
        # True value indicates that the corresponding key value will be ignored
        # if self.state_encoder is not None:
        #     with torch.no_grad():
        #         states = self.state_encoder.encode(states)
        # if self.action_encoder is not None:
        #     with torch.no_grad():
        #         actions = self.action_encoder.encode(actions)
        padding_mask = ~mask.to(torch.bool)

        # if self.train_action_encoder and self.action_encoder is not None:
        #     org_actions = copy.deepcopy(actions)
        #     # input_actions = self.action_encoder.encode(actions)
        #     actions = self.action_encoder.encode(actions)

        action_preds, cost_preds, state_preds = self.model(
            states=states,
            actions=actions,
            returns_to_go=returns,
            costs_to_go=costs_return,
            time_steps=time_steps,
            task_id=env_id,
            prompt_states=prompt_states,
            prompt_actions=prompt_actions,
            prompt_returns_to_go=prompt_returns,
            prompt_costs_to_go=prompt_costs_return,
            padding_mask=padding_mask,
            episode_cost=episode_cost
        )

        if self.stochastic:
            log_likelihood = action_preds.log_prob(actions)[mask > 0].mean()
            entropy = action_preds.entropy()[mask > 0].mean()
            entropy_reg = self.model.temperature().detach()
            entropy_reg_item = entropy_reg.item()
            if self.no_entropy:
                entropy_reg = 0.0
                entropy_reg_item = 0.0
            act_loss = -(log_likelihood + entropy_reg * entropy)
            self.logger.store(tab="train/"+task_name,
                            nll=-log_likelihood.item(),
                            ent=entropy.item(),
                            ent_reg=entropy_reg_item)
        else:
            # if self.train_action_encoder and self.action_encoder is not None:
            #     action_preds = self.action_encoder.decode(action_preds)
            #     act_loss = F.mse_loss(action_preds, org_actions, reduction="none")
            # else:
            act_loss = F.mse_loss(action_preds, actions.detach(), reduction="none")
            # [batch_size, seq_len, action_dim] * [batch_size, seq_len, 1]
            act_loss = (act_loss * mask.unsqueeze(-1)).mean()

        # cost_preds: [batch_size * seq_len, 2], costs: [batch_size * seq_len]
        cost_preds = cost_preds.reshape(-1, 2)
        costs = costs.flatten().long().detach()
        cost_loss = F.nll_loss(cost_preds, costs, reduction="none")
        # cost_loss = F.mse_loss(cost_preds, costs.detach(), reduction="none")
        cost_loss = (cost_loss * mask.flatten()).mean()
        # compute the accuracy, 0 value, 1 indice, [batch_size, seq_len]
        pred = cost_preds.data.max(dim=1)[1]
        correct = pred.eq(costs.data.view_as(pred)) * mask.flatten()
        correct = correct.sum()
        total_num = mask.sum()
        acc = correct / total_num

        # [batch_size, seq_len, state_dim]
        # if self.model.prompt_concat:
        #     state_loss = F.mse_loss(state_preds[:, :-1],
        #                             torch.cat([states[:, 1:],prompt[:, 1:]],dim=-1).detach(),
        #                             reduction="none")
        # else:
        state_loss = F.mse_loss(state_preds[:, :-1],
                                states[:, 1:].detach(),
                                reduction="none")
        state_loss = (state_loss * mask[:, :-1].unsqueeze(-1)).mean()

        loss = act_loss + self.cost_weight * cost_loss + self.state_weight * state_loss

        self.optim.zero_grad()
        loss.backward()
        if self.clip_grad is not None:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_grad)
            # if self.action_encoder is not None and self.train_action_encoder:
            #     torch.nn.utils.clip_grad_norm_(self.action_encoder.parameters(), self.clip_grad)
        self.optim.step()

        if self.stochastic:
            self.log_temperature_optimizer.zero_grad()
            temperature_loss = (self.model.temperature() *
                                (entropy - self.model.target_entropy[task_id]).detach())
            temperature_loss.backward()
            self.log_temperature_optimizer.step()

        self.scheduler.step()
        self.logger.store(
            tab="train/"+task_name,
            all_loss=loss.item(),
            act_loss=act_loss.item(),
            cost_loss=cost_loss.item(),
            cost_acc=acc.item(),
            state_loss=state_loss.item(),
            train_lr=self.scheduler.get_last_lr()[0],
        )

    def evaluate(self, num_rollouts, target_return, target_cost, task_id, env_id, episode_len, state_dim, action_dim, prompt_states, prompt_actions, prompt_returns_to_go, prompt_costs_to_go, keep_ctg_positive=False, conservative=False):
        """
        Evaluates the performance of the model on a number of episodes.
        """
        self.model.eval()
        # if self.train_action_encoder and self.action_encoder is not None:
        #     self.action_encoder.eval()
        episode_rets, episode_costs, episode_lens = [], [], []
        for _ in trange(num_rollouts, desc="Evaluating...", leave=False):
            if conservative and self.rtg_model is not None:
                epi_ret, epi_len, epi_cost = self.rollout(self.model, self.env_ls[task_id],
                                                      target_return, target_cost, env_id, episode_len, state_dim, action_dim, prompt_states, prompt_actions, prompt_returns_to_go, prompt_costs_to_go, keep_ctg_positive=keep_ctg_positive,
                                                      conservative=True, rtg_model=self.rtg_model)
            else:
                epi_ret, epi_len, epi_cost = self.rollout(self.model, self.env_ls[task_id],
                                                        target_return, target_cost, env_id, episode_len, state_dim, action_dim, prompt_states, prompt_actions, prompt_returns_to_go, prompt_costs_to_go, keep_ctg_positive=keep_ctg_positive)
            episode_rets.append(epi_ret)
            episode_lens.append(epi_len)
            episode_costs.append(epi_cost)
        self.model.train()
        # if self.train_action_encoder and self.action_encoder is not None:
        #     self.action_encoder.train()
        return np.mean(episode_rets) / self.reward_scale, np.mean(
            episode_costs) / self.cost_scale, np.mean(episode_lens)
    

    @torch.no_grad()
    def rollout(
        self,
        model: CDT,
        env: gym.Env,
        target_return: float,
        target_cost: float,
        task_id,
        model_episode_len,
        state_dim,
        action_dim,
        prompt_states,
        prompt_actions,
        prompt_returns_to_go,
        prompt_costs_to_go,
        keep_ctg_positive: bool = False,
        conservative: bool = False,
        rtg_model = None
    ) -> Tuple[float, int, float]:
        """
        Evaluates the performance of the model on a single episode.

        Args:
            model: policy network; ``model.seq_len`` bounds the context fed to
                it together with the prompt.
            env: environment to step; ``info["cost"]`` is read at every step.
            target_return (float): initial return-to-go.
            target_cost (float): initial cost-to-go; also passed to the model
                as the episode cost.
            task_id: task identifier forwarded to the model.
            model_episode_len (int): maximum number of environment steps.
            state_dim (int): last dimension of the state buffer.
            action_dim (int): last dimension of the action buffer.
            prompt_states, prompt_actions, prompt_returns_to_go,
            prompt_costs_to_go: prompt tensors forwarded to the model;
                ``prompt_states.shape[1]`` is taken as the prompt length.
            keep_ctg_positive (bool): clamp the cost-to-go at zero.
            conservative (bool): request RTG-model based return targets
                (not supported here, see Raises).
            rtg_model: return-to-go model for the conservative mode.

        Returns:
            ``(episode_ret, episode_len, episode_cost)``: the sum of raw
            rewards, the number of steps taken and the sum of raw
            ``info["cost"]`` values.

        Raises:
            NotImplementedError: if ``conservative`` is true and ``rtg_model``
                is given. This path was previously gated by ``assert False``
                (which is stripped under ``python -O``) and referenced prompt
                variables that do not exist in this method. The disabled logic
                sampled ``self.rtg_sample_num`` values from the RTG model,
                took the ``self.rtg_sample_quantile`` quantile as the new
                target, and clipped the next return-to-go to ``[0, target]``.
        """
        if conservative and rtg_model is not None:
            # Fail before touching the environment.
            raise NotImplementedError(
                "conservative rollouts with an rtg_model are not supported by "
                "PromptCDTTrainer.rollout")

        # Episode buffers, all with a leading batch dimension of 1.
        states = torch.zeros(1,
                             model_episode_len + 1,
                             state_dim,
                             dtype=torch.float,
                             device=self.device)
        actions = torch.zeros(1,
                              model_episode_len,
                              action_dim,
                              dtype=torch.float,
                              device=self.device)
        returns = torch.zeros(1,
                              model_episode_len + 1,
                              dtype=torch.float,
                              device=self.device)
        costs = torch.zeros(1,
                            model_episode_len + 1,
                            dtype=torch.float,
                            device=self.device)
        time_steps = torch.arange(model_episode_len,
                                  dtype=torch.long,
                                  device=self.device).view(1, -1)

        obs, _ = env.reset()
        states[:, 0] = torch.as_tensor(obs, device=self.device)
        costs[:, 0] = torch.as_tensor(target_cost, device=self.device)
        returns[:, 0] = torch.as_tensor(target_return, device=self.device)

        # Quantile decrement applied per costly step. A zero cost budget
        # previously raised ZeroDivisionError; treat it as an infinite step so
        # the first cost drops the quantile straight to its lower bound.
        quantile_span = self.rtg_sample_quantile_start - self.rtg_sample_quantile_end
        if target_cost != 0:
            self.rtg_sample_quantile_step = quantile_span / target_cost
        else:
            self.rtg_sample_quantile_step = float("inf")

        epi_cost = torch.tensor(np.array([target_cost]),
                                dtype=torch.float,
                                device=self.device)

        # Number of most recent steps given to the model besides the prompt.
        context_len = model.seq_len - prompt_states.shape[1]

        # cannot step higher than model episode len, as timestep embeddings will crash
        episode_ret, episode_cost, episode_len = 0.0, 0.0, 0
        for step in range(model_episode_len):
            # Select the history up to `step` (inclusive), then keep the last
            # `context_len` entries. The action at `step` is still zeros; the
            # model predicts it.
            s = states[:, :step + 1][:, -context_len:]
            a = actions[:, :step + 1][:, -context_len:]
            r = returns[:, :step + 1][:, -context_len:]
            c = costs[:, :step + 1][:, -context_len:]
            t = time_steps[:, :step + 1][:, -context_len:]

            acts, _, _ = model(s, a, r, c, t, task_id, prompt_states, prompt_actions,
                               prompt_returns_to_go, prompt_costs_to_go, None, epi_cost)
            if self.stochastic:
                acts = acts.mean
            acts = acts.clamp(-self.max_action, self.max_action)
            act = acts[0, -1].cpu().numpy()

            obs_next, reward, terminated, truncated, info = env.step(act)
            if self.cost_reverse:
                cost = (1.0 - info["cost"]) * self.cost_scale
            else:
                cost = info["cost"] * self.cost_scale
            if cost != 0:
                # NOTE(review): the quantile is decayed on the trainer and is
                # not reset to rtg_sample_quantile_start in this method, so it
                # carries over between episodes — confirm this is intended.
                self.rtg_sample_quantile = max(
                    self.rtg_sample_quantile - self.rtg_sample_quantile_step,
                    self.rtg_sample_quantile_end)

            # at step t, we predict a_t, get s_{t + 1}, r_{t + 1}
            actions[:, step] = torch.as_tensor(act)
            states[:, step + 1] = torch.as_tensor(obs_next)

            next_ctg = costs[:, step] - cost
            if keep_ctg_positive:
                next_ctg = torch.clamp(next_ctg, min=0.0)
            costs[:, step + 1] = next_ctg
            returns[:, step + 1] = returns[:, step] - reward

            episode_ret += reward
            episode_len += 1
            episode_cost += info["cost"]

            if terminated or truncated:
                break

        return episode_ret, episode_len, episode_cost

    def get_ensemble_action(self, size: int, model, s, a, r, c, t, epi_cost):
        # [size, seq_len, state_dim]
        s = torch.repeat_interleave(s, size, 0)
        # [size, seq_len, act_dim]
        a = torch.repeat_interleave(a, size, 0)
        # [size, seq_len]
        r = torch.repeat_interleave(r, size, 0)
        c = torch.repeat_interleave(c, size, 0)
        t = torch.repeat_interleave(t, size, 0)
        epi_cost = torch.repeat_interleave(epi_cost, size, 0)

        acts, _, _ = model(s, a, r, c, t, None, epi_cost)
        if self.stochastic:
            acts = acts.mean

        # [size, seq_len, act_dim]
        acts = torch.mean(acts, dim=0, keepdim=True)
        acts = acts.clamp(-self.max_action, self.max_action)
        act = acts[0, -1].cpu().numpy()
        return act

    def collect_random_rollouts(self, num_rollouts):
        episode_rets = []
        for _ in range(num_rollouts):
            obs, info = self.env.reset()
            episode_ret = 0.0
            for step in range(self.model.episode_len):
                act = self.env.action_space.sample()
                obs_next, reward, terminated, truncated, info = self.env.step(act)
                obs = obs_next
                episode_ret += reward
                if terminated or truncated:
                    break
            episode_rets.append(episode_ret)
        return np.mean(episode_rets)