import torch
import torch.nn as nn
from torch.nn import functional as F
import math
import numpy as np
from torch import einsum
from einops import rearrange, repeat
from einops.layers.torch import Rearrange, Reduce
from torch.distributions import Categorical
from algorithms.utils.util import check, init, exists
from algorithms.utils.transformer_act import discrete_autoregreesive_act
from algorithms.utils.transformer_act import discrete_parallel_act
from algorithms.utils.transformer_act import continuous_autoregreesive_act
from algorithms.utils.transformer_act import continuous_parallel_act
from algorithms.utils.transformer_act import discrete_parallel_disc
from algorithms.utils.transformer_act import continuous_parallel_disc
from algorithms.utils.transformer_act import continuous_autoregreesive_distb_gail
from algorithms.utils.transformer_act import continuous_parallel_act_distb_gail


def init_(m, gain=0.01, activate=False):
    """Initialise module ``m`` through the project ``init`` helper.

    Passes ``nn.init.orthogonal_`` as the weight initialiser and a constant-0
    initialiser as the bias initialiser.

    Args:
        m: module to initialise (used here with ``nn.Linear``).
        gain: gain for the orthogonal initialiser; ignored when ``activate`` is True.
        activate: when True, use ``nn.init.calculate_gain('relu')`` as the gain.

    Returns:
        Whatever ``init`` returns (presumably ``m`` itself -- confirm in
        algorithms.utils.util).
    """
    chosen_gain = nn.init.calculate_gain('relu') if activate else gain

    def zero_bias(tensor):
        return nn.init.constant_(tensor, 0)

    return init(m, nn.init.orthogonal_, zero_bias, gain=chosen_gain)


def init_disc(module):
    """Initialise ``nn.Linear`` layers for the discriminator; other modules are left alone.

    Intended for ``nn.Module.apply``. Weights get Xavier-normal init and biases
    are zeroed. Linear layers built with ``bias=False`` have ``bias is None``,
    so the bias step is skipped for them instead of raising.
    """
    if isinstance(module, nn.Linear):
        nn.init.xavier_normal_(module.weight)
        if module.bias is not None:
            nn.init.zeros_(module.bias)


class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention with separate key/value/query inputs.

    Args:
        n_embd: embedding width; must be divisible by ``n_head``.
        n_head: number of attention heads.
        n_agent: the lower-triangular ``mask`` buffer is sized ``n_agent + 1``.
        masked: when True, ``forward`` applies the lower-triangular (causal) mask.
    """

    def __init__(self, n_embd, n_head, n_agent, masked=False):
        super(SelfAttention, self).__init__()

        assert n_embd % n_head == 0
        self.masked = masked
        self.n_head = n_head
        # Projections for keys, queries and values (all heads computed at once).
        self.key = init_(nn.Linear(n_embd, n_embd))
        self.query = init_(nn.Linear(n_embd, n_embd))
        self.value = init_(nn.Linear(n_embd, n_embd))
        # Projection applied after the heads are concatenated again.
        self.proj = init_(nn.Linear(n_embd, n_embd))
        # Lower-triangular mask; registered even when masked=False, so it is
        # always part of the state_dict.
        size = n_agent + 1
        causal = torch.tril(torch.ones(size, size)).view(1, 1, size, size)
        self.register_buffer("mask", causal)

        self.att_bp = None

    def forward(self, key, value, query):
        """Attend from ``query`` over ``key``/``value``.

        All three inputs are reshaped with ``query``'s (B, L, D), so they must
        share batch size, sequence length and width. Returns a (B, L, D) tensor.
        """
        batch, length, width = query.size()
        head_dim = width // self.n_head

        def split_heads(t):
            # (B, L, D) -> (B, nh, L, hs)
            return t.view(batch, length, self.n_head, head_dim).transpose(1, 2)

        k = split_heads(self.key(key))
        q = split_heads(self.query(query))
        v = split_heads(self.value(value))

        # (B, nh, L, hs) x (B, nh, hs, L) -> (B, nh, L, L)
        scale = 1.0 / math.sqrt(k.size(-1))
        scores = (q @ k.transpose(-2, -1)) * scale

        if self.masked:
            visible = self.mask[:, :, :length, :length]
            scores = scores.masked_fill(visible == 0, float('-inf'))
        weights = F.softmax(scores, dim=-1)

        # (B, nh, L, L) x (B, nh, L, hs) -> (B, nh, L, hs) -> (B, L, D)
        merged = (weights @ v).transpose(1, 2).contiguous().view(batch, length, width)
        return self.proj(merged)


class EncodeBlock(nn.Module):
    """Post-norm Transformer encoder layer: unmasked self-attention followed by an MLP.

    Each sub-layer is wrapped as ``LayerNorm(x + sublayer(x))``.
    """

    def __init__(self, n_embd, n_head, n_agent):
        super(EncodeBlock, self).__init__()

        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        self.attn = SelfAttention(n_embd, n_head, n_agent, masked=False)
        hidden = 1 * n_embd
        self.mlp = nn.Sequential(
            init_(nn.Linear(n_embd, hidden), activate=True),
            nn.GELU(),
            init_(nn.Linear(hidden, n_embd)),
        )

    def forward(self, x):
        attended = self.attn(x, x, x)
        x = self.ln1(x + attended)
        return self.ln2(x + self.mlp(x))


class DecodeBlock(nn.Module):
    """Post-norm Transformer decoder layer.

    The layer runs three sub-layers in order:
      1. Self-attention over ``x``, masked when ``masked`` is True.
      2. Attention whose queries come from ``rep_enc`` and whose keys/values come
         from the result of step 1. The residual is added to ``rep_enc``.
      3. An MLP.
    """

    def __init__(self, n_embd, n_head, n_agent, masked=True):
        super(DecodeBlock, self).__init__()

        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        self.ln3 = nn.LayerNorm(n_embd)
        self.attn1 = SelfAttention(n_embd, n_head, n_agent, masked=masked)
        self.attn2 = SelfAttention(n_embd, n_head, n_agent, masked=masked)
        hidden = 1 * n_embd
        self.mlp = nn.Sequential(
            init_(nn.Linear(n_embd, hidden), activate=True),
            nn.GELU(),
            init_(nn.Linear(hidden, n_embd)),
        )

    def forward(self, x, rep_enc):
        x = self.ln1(x + self.attn1(x, x, x))
        cross = self.attn2(key=x, value=x, query=rep_enc)
        x = self.ln2(rep_enc + cross)
        return self.ln3(x + self.mlp(x))


class Encoder(nn.Module):
    """Embed either the state or the observation per agent, then run encoder blocks.

    ``encode_state`` selects which input is embedded: ``state`` when True,
    ``obs`` when False. The other input is ignored. Both embedding MLPs are
    always constructed.
    """

    def __init__(self, state_dim, obs_dim, n_block, n_embd, n_head, n_agent, encode_state):
        super(Encoder, self).__init__()

        self.state_dim = state_dim
        self.obs_dim = obs_dim
        self.n_embd = n_embd
        self.n_agent = n_agent
        self.encode_state = encode_state

        self.state_encoder = nn.Sequential(
            nn.LayerNorm(state_dim),
            init_(nn.Linear(state_dim, n_embd), activate=True),
            nn.GELU(),
        )
        self.obs_encoder = nn.Sequential(
            nn.LayerNorm(obs_dim),
            init_(nn.Linear(obs_dim, n_embd), activate=True),
            nn.GELU(),
        )

        self.ln = nn.LayerNorm(n_embd)
        self.blocks = nn.Sequential(*[EncodeBlock(n_embd, n_head, n_agent) for _ in range(n_block)])
        # No value head here: the critic is a separate module (offline training).

    def forward(self, state, obs):
        """Return the encoded representation of shape (batch, n_agent, n_embd).

        state: (batch, n_agent, state_dim); obs: (batch, n_agent, obs_dim).
        """
        if self.encode_state:
            embedder, source = self.state_encoder, state
        else:
            embedder, source = self.obs_encoder, obs
        return self.blocks(self.ln(embedder(source)))


class Decoder(nn.Module):
    """Produce per-agent action logits (or means, for non-discrete actions).

    The decoder works in one of two modes:

    * ``dec_actor=True``: MLP actors applied directly to ``obs``. There is one
      shared MLP when ``share_actor`` is set, otherwise one MLP per agent.
      ``action`` and ``obs_rep`` are not used.
    * ``dec_actor=False``: ``action`` is embedded and passed through
      ``DecodeBlock`` layers that also read ``obs_rep``, then through a linear head.

    For non-discrete ``action_type`` a learnable ``log_std`` vector is created.
    """

    def __init__(self, obs_dim, action_dim, n_block, n_embd, n_head, n_agent,
                 action_type='Discrete', dec_actor=False, share_actor=False):
        super(Decoder, self).__init__()

        self.action_dim = action_dim
        self.n_embd = n_embd
        self.dec_actor = dec_actor
        self.share_actor = share_actor
        self.action_type = action_type

        if action_type != 'Discrete':
            # NOTE: initialised to ones, not zeros, despite the "log_std" name.
            self.log_std = torch.nn.Parameter(torch.ones(action_dim))

        def make_actor():
            # LN -> Linear -> GELU -> LN -> Linear -> GELU -> LN -> Linear(action_dim)
            return nn.Sequential(
                nn.LayerNorm(obs_dim),
                init_(nn.Linear(obs_dim, n_embd), activate=True), nn.GELU(), nn.LayerNorm(n_embd),
                init_(nn.Linear(n_embd, n_embd), activate=True), nn.GELU(), nn.LayerNorm(n_embd),
                init_(nn.Linear(n_embd, action_dim)),
            )

        if self.dec_actor and self.share_actor:
            print("mac_dec!!!!!")
            self.mlp = make_actor()
        elif self.dec_actor:
            self.mlp = nn.ModuleList([make_actor() for _ in range(n_agent)])
        else:
            if action_type == 'Discrete':
                # One extra input slot beyond action_dim, and no bias.
                first_layer = init_(nn.Linear(action_dim + 1, n_embd, bias=False), activate=True)
            else:
                first_layer = init_(nn.Linear(action_dim, n_embd), activate=True)
            self.action_encoder = nn.Sequential(first_layer, nn.GELU())
            self.obs_encoder = nn.Sequential(
                nn.LayerNorm(obs_dim),
                init_(nn.Linear(obs_dim, n_embd), activate=True),
                nn.GELU(),
            )
            self.ln = nn.LayerNorm(n_embd)
            self.blocks = nn.Sequential(*[DecodeBlock(n_embd, n_head, n_agent) for _ in range(n_block)])
            self.head = nn.Sequential(
                init_(nn.Linear(n_embd, n_embd), activate=True), nn.GELU(), nn.LayerNorm(n_embd),
                init_(nn.Linear(n_embd, action_dim)),
            )

    def zero_std(self, device):
        """Reset ``log_std`` to zeros on ``device``; no-op for discrete actions."""
        if self.action_type == 'Discrete':
            return
        self.log_std.data = torch.zeros(self.action_dim).to(device)

    def forward(self, action, obs_rep, obs):
        """Return logits of shape (batch, n_agent, action_dim).

        obs_rep: (batch, n_agent, n_embd). ``action`` is presumably
        (batch, n_agent, action_dim [+1 for Discrete]) -- confirm with callers.
        """
        if not self.dec_actor:
            x = self.ln(self.action_encoder(action))
            for block in self.blocks:
                x = block(x, obs_rep)
            return self.head(x)

        if self.share_actor:
            return self.mlp(obs)

        per_agent = [actor(obs[:, idx, :]) for idx, actor in enumerate(self.mlp)]
        return torch.stack(per_agent, dim=1)


class Critic(nn.Module):
    """Value head: a three-layer MLP mapping each ``n_embd`` feature vector to one scalar.

    ``state_dim`` and ``obs_dim`` are stored as attributes but do not affect
    the network.
    """

    def __init__(self, state_dim, obs_dim, n_embd):
        super(Critic, self).__init__()
        self.state_dim = state_dim
        self.obs_dim = obs_dim
        self.n_embd = n_embd
        layers = [
            init_(nn.Linear(n_embd, n_embd), activate=True), nn.GELU(),
            init_(nn.Linear(n_embd, n_embd), activate=True), nn.GELU(),
            init_(nn.Linear(n_embd, 1)),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, rep):
        """Map encoder output ``rep`` (..., n_embd) to values (..., 1)."""
        value = self.net(rep)
        return value

################# use for gmlp disc model


class Residual(nn.Module):
    """Add an identity skip connection around ``fn``: returns ``fn(x) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        transformed = self.fn(x)
        return transformed + x


class PreShiftTokens(nn.Module):
    """Shift groups of feature channels along the sequence axis, then call ``fn``.

    The last dimension of ``x`` is split into ``len(shifts)`` chunks of width
    ``x.shape[-1] // len(shifts)``. Any remainder chunk is left untouched.
    Chunk ``i`` is shifted along dim -2 by ``shifts[i]`` positions, and the
    vacated positions are filled with zeros. When ``shifts == (0,)`` the input
    is passed straight to ``fn``.

    Args:
        shifts: iterable of integer shift amounts, one per channel group.
        fn: callable applied to the (shifted) tensor; extra kwargs are forwarded.
    """

    def __init__(self, shifts, fn):
        super().__init__()
        self.fn = fn
        self.shifts = tuple(shifts)

    @staticmethod
    def _shift(t, amount):
        """Shift ``t`` by ``amount`` along dim -2 with zero fill; negative shifts go backwards."""
        if amount == 0:
            return t
        # F.pad order is (last_left, last_right, dim-2_left, dim-2_right).
        # The negative pad on one side crops, so the sequence length is unchanged.
        return F.pad(t, (0, 0, amount, -amount), value=0.)

    def forward(self, x, **kwargs):
        if self.shifts == (0,):
            return self.fn(x, **kwargs)

        shifts = self.shifts
        segments = len(shifts)
        feats_per_shift = x.shape[-1] // segments
        splitted = x.split(feats_per_shift, dim=-1)
        segments_to_shift, rest = splitted[:segments], splitted[segments:]
        # Previously called an undefined module-level `shift` (NameError for any
        # non-(0,) shifts); now uses the private helper above.
        shifted = [self._shift(seg, amount) for seg, amount in zip(segments_to_shift, shifts)]
        x = torch.cat((*shifted, *rest), dim=-1)
        return self.fn(x, **kwargs)


class PreNorm(nn.Module):
    """Apply ``LayerNorm(dim)`` to the input before calling ``fn`` (kwargs are forwarded)."""

    def __init__(self, dim, fn):
        super().__init__()
        self.fn = fn
        self.norm = nn.LayerNorm(dim)

    def forward(self, x, **kwargs):
        normed = self.norm(x)
        return self.fn(normed, **kwargs)


class Attention(nn.Module):
    """Single-head attention over (b, n, dim_in) inputs, projected to ``dim_out``.

    Scores are scaled by ``dim_inner ** -0.5``. When ``causal`` is True, position
    i cannot attend to positions j > i: those scores are filled with the dtype's
    most negative finite value before the softmax.
    """

    def __init__(self, dim_in, dim_out, dim_inner, causal=False):
        super().__init__()
        self.scale = dim_inner ** -0.5
        self.causal = causal

        self.to_qkv = nn.Linear(dim_in, dim_inner * 3, bias=False)
        self.to_out = nn.Linear(dim_inner, dim_out)

    def forward(self, x):
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        if self.causal:
            n_i, n_j = sim.shape[-2:]
            future = torch.ones(n_i, n_j, device=x.device).triu(1).bool()
            sim.masked_fill_(future.unsqueeze(0), -torch.finfo(q.dtype).max)

        weights = sim.softmax(dim=-1)
        return self.to_out(einsum('b i j, b j d -> b i d', weights, v))


class SpatialGatingUnit(nn.Module):
    """gMLP spatial gating unit: gates half the channels with a sequence-mixed other half.

    The input's last dimension (``dim``) is split in two halves, ``res`` and
    ``gate``. ``gate`` is LayerNorm'd and mixed across the sequence dimension by
    a learned per-head (dim_seq x dim_seq) weight plus a per-head, per-position
    bias. The result is ``act(gate) * res``, with last dimension ``dim // 2``.

    Note: ``act`` defaults to a single ``nn.Tanh()`` instance created at
    definition time and shared by every unit built with the default (harmless,
    since Tanh has no state).
    """
    def __init__(
            self, dim, dim_seq, causal=False, act=nn.Tanh(), heads=1, init_eps=1e-3, circulant_matrix=False
    ):
        """
        Args:
            dim: input feature width; split in half, and ``dim // 2`` must be
                divisible by ``heads`` for the head rearrange in ``forward``.
            dim_seq: sequence length the mixing weight is sized for.
            causal: if True, position m only mixes positions n <= m.
            act: activation applied to the mixed gate.
            heads: number of independent mixing heads.
            init_eps: weight is initialised U(-init_eps/dim_seq, init_eps/dim_seq).
            circulant_matrix: if True, store only one row per head and expand it
                into a circulant matrix in ``forward``, scaled by learned per-row
                and per-column position factors.
        """
        super().__init__()
        dim_out = dim // 2
        self.heads = heads
        self.causal = causal
        self.norm = nn.LayerNorm(dim_out)

        self.act = act

        # parameters

        if circulant_matrix:
            self.circulant_pos_x = nn.Parameter(torch.ones(heads, dim_seq))
            self.circulant_pos_y = nn.Parameter(torch.ones(heads, dim_seq))

        self.circulant_matrix = circulant_matrix
        shape = (heads, dim_seq,) if circulant_matrix else (heads, dim_seq, dim_seq)
        weight = torch.zeros(shape)

        self.weight = nn.Parameter(weight)
        # Scale the init range by sequence length so the initial mixing stays
        # small; together with bias=1 the gate starts close to constant.
        init_eps /= dim_seq
        nn.init.uniform_(self.weight, -init_eps, init_eps)

        self.bias = nn.Parameter(torch.ones(heads, dim_seq))

    def forward(self, x, gate_res=None):
        """Apply spatial gating to ``x`` of shape (b, n, dim); returns (b, n, dim // 2).

        ``gate_res``, if given, is added to the mixed gate before ``act``.
        When ``causal`` is False the weight is not sliced, so ``n`` must equal
        ``dim_seq`` for the einsum to match.
        """
        device, n, h = x.device, x.shape[1], self.heads

        res, gate = x.chunk(2, dim=-1)
        gate = self.norm(gate)

        weight, bias = self.weight, self.bias

        if self.circulant_matrix:
            # build the circulant matrix
            dim_seq = weight.shape[-1]
            weight = F.pad(weight, (0, dim_seq), value=0)
            weight = repeat(weight, '... n -> ... (r n)', r=dim_seq)
            weight = weight[:, :-dim_seq].reshape(h, dim_seq, 2 * dim_seq - 1)
            weight = weight[:, :, (dim_seq - 1):]

            # give circulant matrix absolute position awareness
            pos_x, pos_y = self.circulant_pos_x, self.circulant_pos_y
            weight = weight * rearrange(pos_x, 'h i -> h i ()') * rearrange(pos_y, 'h j -> h () j')

        # Causal masking: crop weight/bias to the actual length n and zero the
        # strictly upper-triangular entries so output position m only mixes
        # inputs n' <= m. (Author's open question: could this avoid a fixed length?)
        if self.causal:
            weight, bias = weight[:, :n, :n], bias[:, :n]
            mask = torch.ones(weight.shape[-2:], device=device).triu_(1).bool()
            mask = rearrange(mask, 'i j -> () i j')
            weight = weight.masked_fill(mask, 0.)

        # Split the feature dim into heads and move the head axis before the sequence axis.
        gate = rearrange(gate, 'b n (h d) -> b h n d', h=h)

        # Per head: gate[m] = sum_n weight[m, n] * gate[n] + bias[m]
        gate = einsum('b h n d, h m n -> b h m d', gate, weight)
        gate = gate + rearrange(bias, 'h n -> () h n ()')

        # Move the head axis back and merge heads into the feature dim.
        gate = rearrange(gate, 'b h n d -> b n (h d)')

        if exists(gate_res):
            gate = gate + gate_res

        return self.act(gate) * res


class gMLPBlock(nn.Module):
    """One gMLP block: channel projection up, spatial gating, channel projection down.

    Args:
        dim: input/output feature width.
        dim_ff: width after ``proj_in``. The gating unit halves it, so
            ``proj_out`` maps ``dim_ff // 2`` back to ``dim``.
        seq_len: sequence length handed to the spatial gating unit.
        heads: number of gating heads.
        attn_dim: if > 0, an ``Attention`` branch with this inner width produces
            ``gate_res``, which is added to the gate. 0 disables the branch.
        causal: forwarded to both the attention branch and the gating unit.
        act: activation applied to the gate.
        circulant_matrix: forwarded to the gating unit.
    """

    def __init__(self, dim, dim_ff, seq_len, heads=1, attn_dim=0,
                 causal=False, act=nn.Tanh(), circulant_matrix=False):
        super().__init__()
        self.proj_in = nn.Sequential(nn.Linear(dim, dim_ff), nn.GELU())
        if attn_dim > 0:
            self.attn = Attention(dim, dim_ff // 2, attn_dim, causal)
        else:
            self.attn = None
        self.sgu = SpatialGatingUnit(dim_ff, seq_len, causal, act, heads,
                                     circulant_matrix=circulant_matrix)
        self.proj_out = nn.Linear(dim_ff // 2, dim)

    def forward(self, x):
        gate_res = None
        if exists(self.attn):
            # The attention branch reads the block input, not the projected one.
            gate_res = self.attn(x)
        hidden = self.proj_in(x)
        gated = self.sgu(hidden, gate_res=gate_res)
        return self.proj_out(gated)


################# use for gmlp disc model
class Discriminator(nn.Module):
    """gMLP-based GAIL discriminator that scores (observation, action) pairs per agent.

    Observations and actions are embedded separately, each followed by a
    LayerNorm. The embeddings are concatenated, or summed when
    ``disc_gmlp_add_embd`` is set. The result goes through ``n_block``
    residual gMLP blocks and a LayerNorm+Linear head giving one logit per token,
    i.e. shape (batch, n_agent, 1). ``states`` is accepted by the public methods
    but not used.
    """

    def __init__(self, state_dim, obs_dim, action_dim, n_block, n_embd, n_head, n_agent,
                 action_type, device, disc_share_value=False,
                 disc_gmlp_dim_ff=2, disc_gmlp_use_causal=False,
                 disc_gmlp_add_embd=False, disc_gmlp_obs_encoder=False):
        super().__init__()
        self.state_dim = state_dim
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.n_embd = n_embd
        self.n_agent = n_agent
        self.action_type = action_type
        # if True, every agent receives the sum of all agents' rewards
        self.disc_share_value = disc_share_value
        # add (True) or concatenate (False) obs and action embeddings
        self.disc_gmlp_add_embd = disc_gmlp_add_embd
        self.disc_gmlp_obs_encoder = disc_gmlp_obs_encoder
        # config data type and device
        self.tpdv = dict(dtype=torch.float32, device=device)
        # obs layer norm after encoding
        self.obs_ln = nn.LayerNorm(n_embd)
        self.obs_ln1 = nn.LayerNorm(n_embd) if self.disc_gmlp_obs_encoder else None
        # action layer norm after encoding
        self.act_ln = nn.LayerNorm(n_embd)
        # optional gMLP stack applied to the obs embedding
        self.obs_gmlp_encoder = nn.Sequential(*[
            Residual(PreNorm(self.n_embd, PreShiftTokens(tuple(range(0, 1)), gMLPBlock(
                dim=self.n_embd, dim_ff=self.n_embd * disc_gmlp_dim_ff,
                seq_len=n_agent, heads=n_head, attn_dim=0,
                causal=False, act=nn.Tanh(),
                circulant_matrix=False)))) for _ in range(n_block)
        ]) if self.disc_gmlp_obs_encoder else None
        # gMLP blocks over the combined obs/action embedding
        self.gmlp_input_dim = 2 * n_embd if not self.disc_gmlp_add_embd else n_embd
        self.blocks = nn.Sequential(*[
            Residual(PreNorm(self.gmlp_input_dim, PreShiftTokens(tuple(range(0, 1)), gMLPBlock(
                dim=self.gmlp_input_dim, dim_ff=self.gmlp_input_dim * disc_gmlp_dim_ff,
                seq_len=n_agent, heads=n_head, attn_dim=0,
                causal=disc_gmlp_use_causal, act=nn.Tanh(),
                circulant_matrix=False)))) for _ in range(n_block)
        ])
        # output layer: one logit per token
        self.head = nn.Sequential(
            nn.LayerNorm(self.gmlp_input_dim),
            nn.Linear(self.gmlp_input_dim, 1),
        )
        # Xavier-init every Linear created so far. This runs BEFORE the encoders
        # below are built, so those keep their init_ (orthogonal) initialisation.
        self.apply(init_disc)

        # obs encoder for gmlp disc
        self.obs_encoder = nn.Sequential(nn.LayerNorm(obs_dim),
                                         init_(nn.Linear(obs_dim, n_embd), activate=True), nn.GELU())
        # action encoder for gmlp disc
        if action_type == 'Discrete':
            self.action_encoder = nn.Sequential(init_(nn.Linear(action_dim, n_embd, bias=False), activate=True), nn.GELU())
        else:
            self.action_encoder = nn.Sequential(init_(nn.Linear(action_dim, n_embd), activate=True), nn.GELU())

    def forward(self, states, obs, actions, available_actions=None):
        """Return discriminator logits of shape (batch, n_agent, 1).

        obs: (batch, n_agent, obs_dim).
        actions: (batch, n_agent, act_dim). For discrete actions, a trailing
            dim of size 1 is treated as an action index and one-hot encoded.
        ``states`` and ``available_actions`` are unused.
        """
        obs = self.obs_ln(self.obs_encoder(obs))
        if self.disc_gmlp_obs_encoder:
            obs = self.obs_ln1(self.obs_gmlp_encoder(obs))
        # change action to onehot (if not one-hot already) before feeding the action encoder
        if self.action_type == 'Discrete' and actions.shape[-1] == 1:
            actions = F.one_hot(actions.squeeze(-1).long(), num_classes=self.action_dim).float()
        actions = self.act_ln(self.action_encoder(actions))
        logits = self.blocks(torch.cat([obs, actions], dim=-1) if not self.disc_gmlp_add_embd else obs + actions)
        disc_value = self.head(logits)

        return disc_value

    def _logits_to_rewards(self, logits):
        """Map logits (batch, n_agent, 1) to rewards ``-log(sigmoid(logits))`` of the same shape.

        With ``disc_share_value`` set, the per-agent rewards are summed over the
        agent axis (dim 1) and that total is given to every agent.
        """
        rewards = -F.logsigmoid(logits)
        if self.disc_share_value:
            team_total = torch.sum(rewards, dim=1).unsqueeze(-1)  # (batch, 1, 1)
            rewards = torch.repeat_interleave(team_total, self.n_agent, dim=1)
        return rewards

    def get_rewards(self, states, obs, actions):
        """Run the discriminator and return rewards as a tensor (autograd graph kept)."""
        return self._logits_to_rewards(self.forward(states, obs, actions))

    def get_rewards_from_logits(self, logits):
        """Convert precomputed logits to rewards as a detached NumPy array.

        logits: torch.Size([batch, num_agents, 1])
        returns: np.ndarray of shape (batch, num_agents, 1) (not flattened)
        """
        return self._logits_to_rewards(logits).detach().cpu().numpy()


class MultiAgentTransformerGailGMLP(nn.Module):
    """Multi-Agent Transformer (encoder + decoder) with an optional critic and gMLP GAIL discriminator.

    The global ``state`` input is ignored throughout. Every public entry point
    replaces it with an all-zero float32 placeholder whose last dimension is
    ``DUMMY_STATE_DIM``. ``critic`` and ``discriminator`` are only built when
    ``use_gail`` is True; otherwise they are ``None``, and the methods that use
    them will fail.
    """

    # Width of the zero placeholder fed in place of the (unused) global state.
    DUMMY_STATE_DIM = 37

    def __init__(self, state_dim, obs_dim, action_dim, n_agent,
                 n_block, n_embd, n_head, disc_inner_dim, encode_state=False,
                 device=torch.device("cpu"), action_type='Discrete',
                 dec_actor=False, share_actor=False, use_gail=False,
                 disc_share_value=False, disc_gmlp_dim_ff=2,
                 disc_gmlp_use_causal=False, disc_gmlp_add_embd=False,
                 disc_gmlp_obs_encoder=False):
        super(MultiAgentTransformerGailGMLP, self).__init__()

        self.n_agent = n_agent
        self.action_dim = action_dim
        self.tpdv = dict(dtype=torch.float32, device=device)
        self.action_type = action_type
        self.device = device

        # if use gail to train model (default False)
        self.use_gail = use_gail

        # The caller's state_dim is ignored: the real state is never used, and
        # every entry point substitutes a zero placeholder of this width.
        state_dim = self.DUMMY_STATE_DIM

        self.encoder = Encoder(state_dim, obs_dim, n_block, n_embd, n_head, n_agent, encode_state)
        self.decoder = Decoder(obs_dim, action_dim, n_block, n_embd, n_head, n_agent,
                               self.action_type, dec_actor=dec_actor, share_actor=share_actor)
        # critic network is optional according to train algorithm
        self.critic = Critic(state_dim, obs_dim, n_embd) if self.use_gail else None
        # discriminator network is optional according to train algorithm
        self.discriminator = Discriminator(
            state_dim=state_dim, obs_dim=obs_dim, action_dim=action_dim,
            n_block=n_block, n_embd=disc_inner_dim, n_head=n_head, n_agent=n_agent,
            action_type=self.action_type, device=device, disc_share_value=disc_share_value,
            disc_gmlp_dim_ff=disc_gmlp_dim_ff, disc_gmlp_use_causal=disc_gmlp_use_causal,
            disc_gmlp_add_embd=disc_gmlp_add_embd, disc_gmlp_obs_encoder=disc_gmlp_obs_encoder,
        ) if self.use_gail else None
        # move all sub-modules to the target device
        self.to(device)

    def _dummy_state(self, reference):
        """Build the zero placeholder state on ``self.tpdv``.

        Its shape is ``reference.shape[:-1] + (DUMMY_STATE_DIM,)``.
        """
        ori_shape = np.shape(reference)
        state = np.zeros((*ori_shape[:-1], self.DUMMY_STATE_DIM), dtype=np.float32)
        return check(state).to(**self.tpdv)

    def zero_std(self):
        """Reset the decoder's ``log_std`` to zeros (continuous actions only)."""
        if self.action_type != 'Discrete':
            self.decoder.zero_std(self.device)

    def forward(self, state, obs, action, available_actions=None):
        """Evaluate given actions.

        Returns ``(action_log, entropy)`` as produced by ``discrete_parallel_act``
        or ``continuous_parallel_act``.

        state: ignored (only its leading dims shape the placeholder).
        obs: (batch, n_agent, obs_dim); action: (batch, n_agent, 1);
        available_actions: (batch, n_agent, act_dim) or None.
        """
        state = self._dummy_state(state)
        obs = check(obs).to(**self.tpdv)
        action = check(action).to(**self.tpdv)

        if available_actions is not None:
            available_actions = check(available_actions).to(**self.tpdv)

        batch_size = np.shape(state)[0]
        obs_rep = self.encoder(state, obs)
        if self.action_type == 'Discrete':
            action = action.long()
            action_log, entropy = discrete_parallel_act(self.decoder, obs_rep, obs, action, batch_size,
                                                        self.n_agent, self.action_dim, self.tpdv, available_actions)
        else:
            action_log, entropy = continuous_parallel_act(self.decoder, obs_rep, obs, action, batch_size,
                                                          self.n_agent, self.action_dim, self.tpdv)

        return action_log, entropy

    def get_actions(self, state, obs, available_actions=None, deterministic=False):
        """Return ``(action, action_log, action_prob)`` from the autoregressive act helpers.

        ``state`` is ignored; the placeholder is shaped from ``obs``.
        """
        state = self._dummy_state(obs)
        obs = check(obs).to(**self.tpdv)
        if available_actions is not None:
            available_actions = check(available_actions).to(**self.tpdv)

        batch_size = np.shape(obs)[0]
        obs_rep = self.encoder(state, obs)
        if self.action_type == "Discrete":
            output_action, output_action_log, output_action_prob = discrete_autoregreesive_act(
                self.decoder, obs_rep, obs, batch_size, self.n_agent, self.action_dim, self.tpdv, available_actions, deterministic)
        else:
            output_action, output_action_log, output_action_prob = continuous_autoregreesive_act(
                self.decoder, obs_rep, obs, batch_size, self.n_agent, self.action_dim, self.tpdv, deterministic)

        return output_action, output_action_log, output_action_prob

    def get_critic_values(self, state, obs):
        """Return critic values for the encoded observations (requires ``use_gail=True``).

        state: ignored (only its leading dims shape the placeholder)
        obs: torch.Size([ep_len, agent_num, obs_dim])
        returns: critic output on obs_rep of shape [ep_len, agent_num, n_embd]
        """
        state = self._dummy_state(state)
        obs = check(obs).to(**self.tpdv)
        # critic network gets the representation produced by the encoder
        obs_rep = self.encoder(state, obs)
        v_tot = self.critic(obs_rep)

        return v_tot

    def get_discriminator_logit(self, state, obs, action):
        """Return discriminator logits (requires ``use_gail=True``).

        state: ignored (only its leading dims shape the placeholder)
        obs: torch.Size([ep_len, agent_num, obs_dim])
        action: torch.Size([ep_len, agent_num, act_dim])
        """
        state = self._dummy_state(state)
        obs = check(obs).to(**self.tpdv)
        action = check(action).to(**self.tpdv)
        dis_logit = self.discriminator(state, obs, action)

        return dis_logit

    def get_discriminator_reward(self, state, obs, action):
        """Return discriminator rewards as a tensor (requires ``use_gail=True``)."""
        state = self._dummy_state(state)
        obs = check(obs).to(**self.tpdv)
        action = check(action).to(**self.tpdv)
        dis_reward = self.discriminator.get_rewards(state, obs, action)

        return dis_reward

    def get_discriminator_rewards_from_logits(self, logits):
        """Convert precomputed discriminator logits to a NumPy reward array."""
        return self.discriminator.get_rewards_from_logits(logits)
