import torch
import torch.nn as nn
from torch.nn import functional as F
import math
import numpy as np
from torch.distributions import Categorical
from algorithms.utils.util import check, init, Multi_Agent_Distribute
from algorithms.utils.transformer_act import discrete_autoregreesive_act
from algorithms.utils.transformer_act import discrete_parallel_act
from algorithms.utils.transformer_act import continuous_autoregreesive_act
from algorithms.utils.transformer_act import continuous_parallel_act
from algorithms.utils.transformer_act import discrete_parallel_disc
from algorithms.utils.transformer_act import continuous_parallel_disc
from algorithms.utils.transformer_act import continuous_autoregreesive_distb_gail
from algorithms.utils.transformer_act import continuous_parallel_act_distb_gail


def init_(m, gain=0.01, activate=False):
    """Initialise module ``m`` through the project ``init`` helper.

    ``init`` receives ``nn.init.orthogonal_`` as the weight initialiser and a
    callable that fills its argument with zeros as the bias initialiser
    (presumably applied to ``m.bias`` — confirm in algorithms.utils.util).

    Args:
        m: module to initialise; forwarded unchanged to ``init``.
        gain: orthogonal-init gain, used only when ``activate`` is false.
        activate: when true, the gain is ``nn.init.calculate_gain('relu')``
            instead of ``gain``.

    Returns:
        Whatever ``init`` returns (callers use it as the initialised module).
    """
    def zero_fill(tensor):
        return nn.init.constant_(tensor, 0)

    chosen_gain = nn.init.calculate_gain('relu') if activate else gain
    return init(m, nn.init.orthogonal_, zero_fill, gain=chosen_gain)


def init_disc(module):
    """Xavier-normal initialise an ``nn.Linear`` weight and zero its bias.

    Designed to be passed to ``nn.Module.apply``; modules that are not
    ``nn.Linear`` are left untouched. Linear layers created with
    ``bias=False`` have ``module.bias is None``, so the bias step is skipped
    for them (calling ``nn.init.zeros_(None)`` would raise).

    Args:
        module: any ``nn.Module``; only ``nn.Linear`` instances are modified.
    """
    if not isinstance(module, nn.Linear):
        return
    nn.init.xavier_normal_(module.weight)
    if module.bias is not None:
        nn.init.zeros_(module.bias)


class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention with an optional causal mask.

    ``key``, ``value`` and ``query`` may be different tensors (``DecodeBlock``
    uses this for cross-attention), but all three are reshaped with the
    query's ``(B, L, D)``, so they must share batch size, length and width.
    """

    def __init__(self, n_embd, n_head, n_agent, masked=False):
        super(SelfAttention, self).__init__()

        assert n_embd % n_head == 0
        self.masked = masked
        self.n_head = n_head
        # Input projections for keys, queries and values, then the output projection.
        self.key = init_(nn.Linear(n_embd, n_embd))
        self.query = init_(nn.Linear(n_embd, n_embd))
        self.value = init_(nn.Linear(n_embd, n_embd))
        self.proj = init_(nn.Linear(n_embd, n_embd))
        # Lower-triangular causal mask, registered regardless of ``masked`` so the
        # state_dict layout is always the same. It is sized n_agent + 1; forward
        # only uses its top-left L x L corner.
        window = n_agent + 1
        causal = torch.tril(torch.ones(window, window))
        self.register_buffer("mask", causal.view(1, 1, window, window))

        # Placeholder attribute; nothing in this class assigns it afterwards.
        self.att_bp = None

    def _split_heads(self, tensor, batch, length, width):
        """Reshape ``(B, L, D)`` into ``(B, n_head, L, D // n_head)``."""
        head_dim = width // self.n_head
        return tensor.view(batch, length, self.n_head, head_dim).transpose(1, 2)

    def forward(self, key, value, query):
        batch, length, width = query.size()

        keys = self._split_heads(self.key(key), batch, length, width)
        queries = self._split_heads(self.query(query), batch, length, width)
        values = self._split_heads(self.value(value), batch, length, width)

        # (B, nh, L, hs) x (B, nh, hs, L) -> (B, nh, L, L), scaled by 1/sqrt(hs)
        scale = 1.0 / math.sqrt(keys.size(-1))
        scores = torch.matmul(queries, keys.transpose(-2, -1)) * scale

        if self.masked:
            hidden = self.mask[:, :, :length, :length] == 0
            scores = scores.masked_fill(hidden, float('-inf'))
        weights = F.softmax(scores, dim=-1)

        # (B, nh, L, L) x (B, nh, L, hs) -> (B, nh, L, hs), then merge heads back to (B, L, D)
        mixed = torch.matmul(weights, values)
        mixed = mixed.transpose(1, 2).contiguous().view(batch, length, width)

        return self.proj(mixed)


class EncodeBlock(nn.Module):
    """Post-LayerNorm transformer encoder block.

    Applies unmasked self-attention and then a two-layer GELU MLP. Each
    sub-layer is wrapped as ``LayerNorm(x + sublayer(x))``.
    """

    def __init__(self, n_embd, n_head, n_agent):
        super(EncodeBlock, self).__init__()

        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        self.attn = SelfAttention(n_embd, n_head, n_agent, masked=False)
        hidden = 1 * n_embd  # the MLP keeps the width at n_embd
        self.mlp = nn.Sequential(
            init_(nn.Linear(n_embd, hidden), activate=True),
            nn.GELU(),
            init_(nn.Linear(hidden, n_embd)),
        )

    def forward(self, x):
        attended = self.ln1(x + self.attn(x, x, x))
        return self.ln2(attended + self.mlp(attended))


class DecodeBlock(nn.Module):
    """Post-LayerNorm transformer decoder block with cross-attention.

    Steps:
      1. (Optionally causal) self-attention over the action stream ``x``.
      2. Cross-attention whose queries come from ``rep_enc`` and whose
         keys/values come from step 1. The residual adds onto ``rep_enc``,
         not onto ``x``.
      3. A two-layer GELU MLP.

    Each step is followed by its own LayerNorm.
    """

    def __init__(self, n_embd, n_head, n_agent, masked=True):
        super(DecodeBlock, self).__init__()

        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)
        self.ln3 = nn.LayerNorm(n_embd)
        self.attn1 = SelfAttention(n_embd, n_head, n_agent, masked=masked)
        self.attn2 = SelfAttention(n_embd, n_head, n_agent, masked=masked)
        hidden = 1 * n_embd
        self.mlp = nn.Sequential(
            init_(nn.Linear(n_embd, hidden), activate=True),
            nn.GELU(),
            init_(nn.Linear(hidden, n_embd)),
        )

    def forward(self, x, rep_enc):
        acts = self.ln1(x + self.attn1(x, x, x))
        fused = self.ln2(rep_enc + self.attn2(key=acts, value=acts, query=rep_enc))
        return self.ln3(fused + self.mlp(fused))


class DecodeBlockNotCrossAtten(nn.Module):
    """Decoder block that fuses observations and actions by concatenation.

    Instead of cross-attention, the encoder representation ``rep_enc``
    (width ``n_embd``) is concatenated in front of the action stream ``x``
    (width ``n_embd``) along the feature axis. The resulting ``2 * n_embd``
    stream goes through (optionally causal) self-attention and a GELU MLP,
    each with a residual connection and a LayerNorm.

    The block outputs ``2 * n_embd`` features. When several blocks are
    stacked, later blocks receive that already-fused stream. They must not
    concatenate ``rep_enc`` again: that would give ``3 * n_embd`` features and
    break ``ln1``/``attn1``. So fusion happens only when ``x`` is as wide as
    ``rep_enc``, i.e. not yet fused.
    """

    def __init__(self, n_embd, n_head, n_agent, masked=True):
        super(DecodeBlockNotCrossAtten, self).__init__()

        self.ln1 = nn.LayerNorm(2 * n_embd)
        self.ln2 = nn.LayerNorm(2 * n_embd)
        self.attn1 = SelfAttention(2 * n_embd, n_head, n_agent, masked=masked)
        self.mlp = nn.Sequential(
            init_(nn.Linear(2 * n_embd, 2 * n_embd), activate=True),
            nn.GELU(),
            init_(nn.Linear(2 * n_embd, 2 * n_embd))
        )

    def forward(self, x, rep_enc):
        """Fuse (if needed) and refine the joint obs/action stream.

        Args:
            x: action embeddings (width ``n_embd``) for the first block, or the
                fused output of a previous block (width ``2 * n_embd``).
            rep_enc: encoder representation of width ``n_embd``.

        Returns:
            Tensor of width ``2 * n_embd``.
        """
        if x.size(-1) == rep_enc.size(-1):
            # First block in the stack: build the fused [rep_enc | actions] stream.
            x = torch.cat([rep_enc, x], dim=-1)
        x = self.ln1(x + self.attn1(x, x, x))
        x = self.ln2(x + self.mlp(x))

        return x


class Encoder(nn.Module):
    """Transformer encoder over per-agent inputs.

    Embeds either the state (``encode_state=True``) or the observation with
    LayerNorm -> Linear -> GELU, applies a LayerNorm, and runs ``n_block``
    :class:`EncodeBlock` layers. Both embedders are always built, but only
    the selected one is used in ``forward``. There is no value head.
    """

    def __init__(self, state_dim, obs_dim, n_block, n_embd, n_head, n_agent, encode_state):
        super(Encoder, self).__init__()

        self.state_dim = state_dim
        self.obs_dim = obs_dim
        self.n_embd = n_embd
        self.n_agent = n_agent
        self.encode_state = encode_state

        self.state_encoder = nn.Sequential(
            nn.LayerNorm(state_dim), init_(nn.Linear(state_dim, n_embd), activate=True), nn.GELU()
        )
        self.obs_encoder = nn.Sequential(
            nn.LayerNorm(obs_dim), init_(nn.Linear(obs_dim, n_embd), activate=True), nn.GELU()
        )

        self.ln = nn.LayerNorm(n_embd)
        self.blocks = nn.Sequential(*[EncodeBlock(n_embd, n_head, n_agent) for _ in range(n_block)])

    def forward(self, state, obs):
        """Return the encoded representation of ``state`` or ``obs``.

        Per the original comments, state is (batch, n_agent, state_dim) and obs
        is (batch, n_agent, obs_dim). The result has width ``n_embd``.
        """
        embedded = self.state_encoder(state) if self.encode_state else self.obs_encoder(obs)
        return self.blocks(self.ln(embedded))


class Decoder(nn.Module):
    """Action decoder of the Multi-Agent Transformer.

    Two modes, chosen by ``dec_actor``:

    * ``dec_actor=False``: embed ``action``, apply a LayerNorm, run ``n_block``
      :class:`DecodeBlock` layers that attend to ``obs_rep``, and project
      each token to ``action_dim`` outputs with ``head``. For discrete
      actions the action embedder takes ``action_dim + 1`` inputs
      (presumably an extra start slot — confirm in transformer_act).
      ``obs_encoder`` is built in this mode but not used by ``forward``.
    * ``dec_actor=True``: ignore ``action``/``obs_rep`` and map each agent's
      ``obs`` straight to ``action_dim`` outputs with an MLP. The MLP is
      shared by all agents when ``share_actor`` is true, otherwise there is
      one per agent.

    For non-discrete action types a learnable ``log_std`` vector, initialised
    to ones, is also created.
    """

    def __init__(self, obs_dim, action_dim, n_block, n_embd, n_head, n_agent,
                 action_type='Discrete', dec_actor=False, share_actor=False):
        super(Decoder, self).__init__()

        self.action_dim = action_dim
        self.n_embd = n_embd
        self.dec_actor = dec_actor
        self.share_actor = share_actor
        self.action_type = action_type

        is_discrete = action_type == 'Discrete'
        if not is_discrete:
            self.log_std = torch.nn.Parameter(torch.ones(action_dim))

        def build_actor():
            # LayerNorm -> Linear/GELU -> LayerNorm -> Linear/GELU -> LayerNorm -> Linear
            return nn.Sequential(
                nn.LayerNorm(obs_dim),
                init_(nn.Linear(obs_dim, n_embd), activate=True), nn.GELU(), nn.LayerNorm(n_embd),
                init_(nn.Linear(n_embd, n_embd), activate=True), nn.GELU(), nn.LayerNorm(n_embd),
                init_(nn.Linear(n_embd, action_dim)),
            )

        if dec_actor:
            if share_actor:
                print("mac_dec!!!!!")
                self.mlp = build_actor()
            else:
                self.mlp = nn.ModuleList([build_actor() for _ in range(n_agent)])
        else:
            if is_discrete:
                first_layer = nn.Linear(action_dim + 1, n_embd, bias=False)
            else:
                first_layer = nn.Linear(action_dim, n_embd)
            self.action_encoder = nn.Sequential(init_(first_layer, activate=True), nn.GELU())
            self.obs_encoder = nn.Sequential(
                nn.LayerNorm(obs_dim), init_(nn.Linear(obs_dim, n_embd), activate=True), nn.GELU()
            )
            self.ln = nn.LayerNorm(n_embd)
            self.blocks = nn.Sequential(*[DecodeBlock(n_embd, n_head, n_agent) for _ in range(n_block)])
            self.head = nn.Sequential(
                init_(nn.Linear(n_embd, n_embd), activate=True), nn.GELU(), nn.LayerNorm(n_embd),
                init_(nn.Linear(n_embd, action_dim)),
            )

    def zero_std(self, device):
        """Reset ``log_std`` to zeros on ``device`` (no-op for discrete actions)."""
        if self.action_type == 'Discrete':
            return
        self.log_std.data = torch.zeros(self.action_dim).to(device)

    def forward(self, action, obs_rep, obs):
        """Return per-agent outputs of width ``action_dim``.

        Per the original comments, action is (batch, n_agent, action_dim) and
        obs_rep is (batch, n_agent, n_embd).
        """
        if not self.dec_actor:
            hidden = self.ln(self.action_encoder(action))
            for block in self.blocks:
                hidden = block(hidden, obs_rep)
            return self.head(hidden)

        if self.share_actor:
            return self.mlp(obs)

        per_agent = [actor(obs[:, idx, :]) for idx, actor in enumerate(self.mlp)]
        return torch.stack(per_agent, dim=1)


class DiscEncoder(nn.Module):
    """Encoder used by the discriminator.

    Same structure as :class:`Encoder`: embed the state or the observation,
    apply a LayerNorm, then run ``n_block`` :class:`EncodeBlock` layers. When
    ``cal_last_loss`` is true, a learnable end token of shape
    ``(1, 1, n_embd)`` (initialised to zeros) is appended along the agent
    axis, and the blocks are sized for ``n_agent + 1`` positions.
    """

    def __init__(self, state_dim, obs_dim, n_block, n_embd, n_head, n_agent, encode_state, cal_last_loss=False):
        super(DiscEncoder, self).__init__()

        self.state_dim = state_dim
        self.obs_dim = obs_dim
        self.n_embd = n_embd
        self.n_agent = n_agent
        self.encode_state = encode_state
        self.cal_last_loss = cal_last_loss

        self.state_encoder = nn.Sequential(
            nn.LayerNorm(state_dim), init_(nn.Linear(state_dim, n_embd), activate=True), nn.GELU()
        )
        self.obs_encoder = nn.Sequential(
            nn.LayerNorm(obs_dim), init_(nn.Linear(obs_dim, n_embd), activate=True), nn.GELU()
        )

        if cal_last_loss:
            self.obs_end_token = nn.Parameter(torch.zeros(1, 1, n_embd, dtype=torch.float32), requires_grad=True)
        else:
            self.obs_end_token = None

        seq_len = n_agent + 1 if cal_last_loss else n_agent
        self.ln = nn.LayerNorm(n_embd)
        self.blocks = nn.Sequential(*[EncodeBlock(n_embd, n_head, seq_len) for _ in range(n_block)])

    def forward(self, state, obs):
        """Encode ``state``/``obs``; append the end token when ``cal_last_loss`` is set.

        Per the original comments, inputs are (batch, n_agent, *_dim). The
        output has n_agent (+1 with the end token) positions of width n_embd.
        """
        tokens = self.state_encoder(state) if self.encode_state else self.obs_encoder(obs)
        if self.cal_last_loss:
            end = self.obs_end_token.repeat(tokens.shape[0], 1, 1)
            tokens = torch.cat([tokens, end], dim=1)
        return self.blocks(self.ln(tokens))


class DiscDecoder(nn.Module):
    """Decoder used by the discriminator; emits one logit per position.

    Steps:
      1. Embed the actions (``action_dim`` inputs; no bias for discrete actions).
      2. Optionally append a learnable end token (``cal_last_loss``).
      3. Apply a LayerNorm.
      4. Run ``n_block`` decoder blocks against ``obs_rep``: :class:`DecodeBlock`,
         or :class:`DecodeBlockNotCrossAtten` when ``drop_cross_atten`` is set.
      5. Map each position to a single value with ``head``.

    ``head`` works on ``2 * n_embd`` features when cross-attention is dropped.
    ``dec_actor``/``share_actor`` are stored but not used by this class.
    ``obs_encoder`` is built but not used by ``forward``.
    """

    def __init__(self, obs_dim, action_dim, n_block, n_embd, n_head, n_agent,
                 action_type='Discrete', dec_actor=False, share_actor=False, masked=True,
                 cal_last_loss=False, drop_cross_atten=False):
        super(DiscDecoder, self).__init__()

        self.action_dim = action_dim
        self.n_embd = n_embd
        self.dec_actor = dec_actor
        self.share_actor = share_actor
        self.action_type = action_type
        self.cal_last_loss = cal_last_loss
        self.drop_cross_atten = drop_cross_atten

        is_discrete = action_type == 'Discrete'
        if not is_discrete:
            self.log_std = torch.nn.Parameter(torch.ones(action_dim))

        if is_discrete:
            first_layer = nn.Linear(action_dim, n_embd, bias=False)
        else:
            first_layer = nn.Linear(action_dim, n_embd)
        self.action_encoder = nn.Sequential(init_(first_layer, activate=True), nn.GELU())

        # Learnable end-of-sequence token for the action stream, shape (1, 1, n_embd).
        if cal_last_loss:
            self.act_end_token = nn.Parameter(torch.zeros(1, 1, n_embd, dtype=torch.float32), requires_grad=True)
        else:
            self.act_end_token = None

        self.obs_encoder = nn.Sequential(
            nn.LayerNorm(obs_dim), init_(nn.Linear(obs_dim, n_embd), activate=True), nn.GELU()
        )
        self.ln = nn.LayerNorm(n_embd)

        seq_len = n_agent + 1 if cal_last_loss else n_agent
        block_cls = DecodeBlockNotCrossAtten if drop_cross_atten else DecodeBlock
        self.blocks = nn.Sequential(*[block_cls(n_embd, n_head, seq_len, masked) for _ in range(n_block)])

        width = 2 * n_embd if drop_cross_atten else n_embd
        self.head_inner_dim = width
        self.head = nn.Sequential(
            init_(nn.Linear(width, width), activate=True),
            nn.GELU(), nn.LayerNorm(width),
            init_(nn.Linear(width, 1)),
        )

    def zero_std(self, device):
        """Reset ``log_std`` to zeros on ``device`` (no-op for discrete actions)."""
        if self.action_type == 'Discrete':
            return
        self.log_std.data = torch.zeros(self.action_dim).to(device)

    def forward(self, action, obs_rep, obs):
        """Return one logit per position (last axis of size 1).

        Per the original comments, action is (batch, n_agent, action_dim) and
        obs_rep is (batch, n_agent, n_embd); both are one position longer when
        the end tokens are enabled.
        """
        embedded = self.action_encoder(action)
        if self.cal_last_loss:
            end = self.act_end_token.repeat(embedded.shape[0], 1, 1)
            embedded = torch.cat([embedded, end], dim=1)
        hidden = self.ln(embedded)
        for block in self.blocks:
            hidden = block(hidden, obs_rep)
        return self.head(hidden)


class Critic(nn.Module):
    """Value head: a three-layer GELU MLP mapping ``n_embd`` features to 1.

    Meant to be fed the encoder output. ``state_dim`` and ``obs_dim`` are
    stored but not used by this class.
    """

    def __init__(self, state_dim, obs_dim, n_embd):
        super(Critic, self).__init__()
        self.state_dim = state_dim
        self.obs_dim = obs_dim
        self.n_embd = n_embd
        hidden_layers = [
            init_(nn.Linear(n_embd, n_embd), activate=True), nn.GELU(),
            init_(nn.Linear(n_embd, n_embd), activate=True), nn.GELU(),
        ]
        self.net = nn.Sequential(*hidden_layers, init_(nn.Linear(n_embd, 1)))

    def forward(self, rep):
        """Return the value estimate for each position of ``rep``."""
        values = self.net(rep)
        return values


class Discriminator(nn.Module):
    """GAIL discriminator built from a :class:`DiscEncoder` / :class:`DiscDecoder` pair.

    The encoder embeds the states (or the observations). The decoder
    processes the action embeddings against that representation and emits
    one logit per position. With ``disc_cal_last_loss`` both sequences get a
    learnable end token, so the logits have ``n_agent + 1`` positions; the
    reward helpers drop that last one.
    """

    def __init__(self, state_dim, obs_dim, action_dim, n_block, n_embd, n_head, n_agent,
                 action_type, device, encode_state=False, dec_actor=False, share_actor=False,
                 disc_share_value=False, disc_mask_action=True,
                 disc_cal_last_loss=False, disc_drop_cross_atten=False):
        super().__init__()
        self.state_dim = state_dim
        self.obs_dim = obs_dim
        self.action_dim = action_dim
        self.n_embd = n_embd
        self.n_agent = n_agent
        self.action_type = action_type
        # How logits are turned into rewards (see _logits_to_rewards).
        self.disc_share_value = disc_share_value
        self.disc_cal_last_loss = disc_cal_last_loss
        self.disc_drop_cross_atten = disc_drop_cross_atten
        self.tpdv = dict(dtype=torch.float32, device=device)
        self.encoder = DiscEncoder(state_dim, obs_dim, n_block, n_embd, n_head, n_agent, encode_state, disc_cal_last_loss)
        self.decoder = DiscDecoder(
            obs_dim, action_dim, n_block, n_embd, n_head, n_agent, self.action_type,
            dec_actor=dec_actor, share_actor=share_actor, masked=disc_mask_action,
            cal_last_loss=disc_cal_last_loss, drop_cross_atten=disc_drop_cross_atten,
        )

    def forward(self, states, obs, actions, available_actions=None):
        """Return discriminator logits for ``(states/obs, actions)``.

        Args:
            states: (batch, n_agent, state_dim).
            obs: (batch, n_agent, obs_dim).
            actions: (batch, n_agent, act_dim). Discrete actions given as
                indices (last axis of size 1) are one-hot encoded first.
            available_actions: accepted for interface compatibility; unused.

        Returns:
            Logits of shape (batch, n_agent, 1), or (batch, n_agent + 1, 1)
            when ``disc_cal_last_loss`` is set.
        """
        obs_rep = self.encoder(states, obs)
        if self.action_type == 'Discrete' and actions.shape[-1] == 1:
            actions = F.one_hot(actions.squeeze(-1).long(), num_classes=self.action_dim).float()
        disc_value = self.decoder(actions, obs_rep, obs)

        return disc_value

    def _logits_to_rewards(self, logits):
        """Turn logits into rewards ``-log(sigmoid(logits))`` of shape (batch, n_agent, 1).

        When ``disc_cal_last_loss`` is set, the trailing end-token position is
        dropped first. With ``disc_share_value``, the per-agent terms are
        summed over the agent axis and that total is repeated ``n_agent``
        times, so every agent gets the same value.
        """
        if self.disc_cal_last_loss:
            logits = logits[:, :-1, :]
        per_agent = -F.logsigmoid(logits)
        if not self.disc_share_value:
            return per_agent
        team_total = torch.sum(per_agent, dim=1).unsqueeze(-1)
        return torch.repeat_interleave(team_total, self.n_agent, dim=1)

    def get_rewards(self, states, obs, actions):
        """Run the discriminator and return rewards as a tensor (batch, n_agent, 1)."""
        logits = self.forward(states, obs, actions)
        return self._logits_to_rewards(logits)

    def get_rewards_from_logits(self, logits):
        """Convert precomputed logits into rewards.

        Args:
            logits: tensor of shape (batch, num_agents, 1), with one extra
                position when ``disc_cal_last_loss`` is set.

        Returns:
            Detached numpy array of shape (batch, n_agent, 1). It is not
            flattened.
        """
        return self._logits_to_rewards(logits).detach().cpu().numpy()


class MultiAgentTransformerGailDec(nn.Module):
    """Multi-Agent Transformer policy with an optional GAIL critic and discriminator.

    The policy is an :class:`Encoder` over per-agent inputs followed by a
    :class:`Decoder` that produces per-agent action outputs. When
    ``use_gail`` is true, a :class:`Critic` and a :class:`Discriminator` are
    built too. Otherwise both are ``None`` and the critic/discriminator entry
    points raise ``RuntimeError``.

    Global state is not used. Every entry point discards the caller's
    ``state`` and substitutes an all-zero float32 tensor whose last axis has
    width ``_PLACEHOLDER_STATE_DIM``. The encoders are built for that width as
    well, whatever ``state_dim`` the caller passes.
    """

    # Width of the all-zero stand-in for the unused global state. The encoders
    # are constructed with this width, so the placeholder tensors must match it.
    _PLACEHOLDER_STATE_DIM = 37

    def __init__(self, state_dim, obs_dim, action_dim, n_agent,
                 n_block, n_embd, n_head, disc_inner_dim, encode_state=False,
                 device=torch.device("cpu"), action_type='Discrete',
                 dec_actor=False, share_actor=False, use_gail=False,
                 disc_share_value=False, disc_mask_action=True,
                 disc_cal_last_loss=False, disc_drop_cross_atten=False):
        super(MultiAgentTransformerGailDec, self).__init__()

        self.n_agent = n_agent
        self.action_dim = action_dim
        self.tpdv = dict(dtype=torch.float32, device=device)
        self.action_type = action_type
        self.device = device

        # Critic and discriminator are only built for GAIL training (default off).
        self.use_gail = use_gail

        # The caller's state_dim is ignored; see _PLACEHOLDER_STATE_DIM.
        state_dim = self._PLACEHOLDER_STATE_DIM

        self.encoder = Encoder(state_dim, obs_dim, n_block, n_embd, n_head, n_agent, encode_state)
        self.decoder = Decoder(obs_dim, action_dim, n_block, n_embd, n_head, n_agent,
                               self.action_type, dec_actor=dec_actor, share_actor=share_actor)
        self.critic = Critic(state_dim, obs_dim, n_embd) if self.use_gail else None
        self.discriminator = Discriminator(
            state_dim=state_dim, obs_dim=obs_dim, action_dim=action_dim,
            n_block=n_block, n_embd=disc_inner_dim, n_head=n_head, n_agent=n_agent,
            action_type=self.action_type, device=device,
            encode_state=encode_state, dec_actor=dec_actor, share_actor=share_actor,
            disc_share_value=disc_share_value, disc_mask_action=disc_mask_action,
            disc_cal_last_loss=disc_cal_last_loss, disc_drop_cross_atten=disc_drop_cross_atten,
        ) if self.use_gail else None
        # Move all sub-modules to the target device.
        self.to(device)

    def zero_std(self):
        """Reset the decoder's ``log_std`` to zeros (continuous actions only)."""
        if self.action_type != 'Discrete':
            self.decoder.zero_std(self.device)

    def _placeholder_state(self, like):
        """Return a zero state tensor with ``like``'s leading dims and the placeholder width."""
        leading = np.shape(like)[:-1]
        state = np.zeros((*leading, self._PLACEHOLDER_STATE_DIM), dtype=np.float32)
        return check(state).to(**self.tpdv)

    def _require(self, module, name):
        """Return ``module``, or raise a clear error if it was not built (use_gail=False)."""
        if module is None:
            raise RuntimeError(
                f"{name} is not available: construct MultiAgentTransformerGailDec with use_gail=True"
            )
        return module

    def forward(self, state, obs, action, available_actions=None):
        """Evaluate given actions; returns ``(action_log, entropy)``.

        Expected shapes: obs (batch, n_agent, obs_dim), action (batch, n_agent, 1),
        available_actions (batch, n_agent, act_dim). ``state`` is replaced by zeros.
        """
        state = self._placeholder_state(state)
        obs = check(obs).to(**self.tpdv)
        action = check(action).to(**self.tpdv)

        if available_actions is not None:
            available_actions = check(available_actions).to(**self.tpdv)

        batch_size = np.shape(state)[0]
        obs_rep = self.encoder(state, obs)
        if self.action_type == 'Discrete':
            action = action.long()
            action_log, entropy = discrete_parallel_act(self.decoder, obs_rep, obs, action, batch_size,
                                                        self.n_agent, self.action_dim, self.tpdv, available_actions)
        else:
            action_log, entropy = continuous_parallel_act(self.decoder, obs_rep, obs, action, batch_size,
                                                          self.n_agent, self.action_dim, self.tpdv)

        return action_log, entropy

    def get_actions(self, state, obs, available_actions=None, deterministic=False):
        """Sample actions autoregressively; returns ``(action, action_log, action_prob)``."""
        # The placeholder is shaped from obs here; state is ignored entirely.
        state = self._placeholder_state(obs)
        obs = check(obs).to(**self.tpdv)
        if available_actions is not None:
            available_actions = check(available_actions).to(**self.tpdv)

        batch_size = np.shape(obs)[0]
        obs_rep = self.encoder(state, obs)
        if self.action_type == "Discrete":
            output_action, output_action_log, output_action_prob = discrete_autoregreesive_act(
                self.decoder, obs_rep, obs, batch_size, self.n_agent, self.action_dim, self.tpdv, available_actions, deterministic)
        else:
            output_action, output_action_log, output_action_prob = continuous_autoregreesive_act(
                self.decoder, obs_rep, obs, batch_size, self.n_agent, self.action_dim, self.tpdv, deterministic)

        return output_action, output_action_log, output_action_prob

    def get_critic_values(self, state, obs):
        """Return critic values for the encoded observations.

        state: (ep_len, agent_num, share_obs_dim), replaced by zeros.
        obs: (ep_len, agent_num, obs_dim).
        Raises RuntimeError if the model was built with use_gail=False.
        """
        critic = self._require(self.critic, "critic")
        state = self._placeholder_state(state)
        obs = check(obs).to(**self.tpdv)
        obs_rep = self.encoder(state, obs)
        return critic(obs_rep)

    def get_discriminator_logit(self, state, obs, action):
        """Return discriminator logits.

        state: (ep_len, agent_num, share_obs_dim), replaced by zeros.
        obs: (ep_len, agent_num, obs_dim).
        action: (ep_len, agent_num, act_dim).
        Raises RuntimeError if the model was built with use_gail=False.
        """
        discriminator = self._require(self.discriminator, "discriminator")
        state = self._placeholder_state(state)
        obs = check(obs).to(**self.tpdv)
        action = check(action).to(**self.tpdv)
        return discriminator(state, obs, action)

    def get_discriminator_reward(self, state, obs, action):
        """Return discriminator rewards as a tensor; raises RuntimeError without use_gail."""
        discriminator = self._require(self.discriminator, "discriminator")
        state = self._placeholder_state(state)
        obs = check(obs).to(**self.tpdv)
        action = check(action).to(**self.tpdv)
        return discriminator.get_rewards(state, obs, action)

    def get_discriminator_rewards_from_logits(self, logits):
        """Convert precomputed logits to numpy rewards; raises RuntimeError without use_gail."""
        discriminator = self._require(self.discriminator, "discriminator")
        return discriminator.get_rewards_from_logits(logits)
