import torch
import numpy as np
from torch.distributions import Categorical, Normal, MultivariateNormal
from torch.nn import functional as F


def discrete_autoregreesive_act(decoder, obs_rep, obs, batch_size, n_agent, action_dim, tpdv,
                                available_actions=None, deterministic=False):
    """Pick discrete actions one agent at a time, feeding each choice to the next agent.

    The decoder input has ``action_dim + 1`` channels per agent: channel 0 is a
    start marker set only for agent 0, channels ``1:`` hold the one-hot action of
    the preceding agent. The decoder is re-run once per agent and only the row
    for the current agent is consumed.

    Args:
        decoder: callable invoked as ``decoder(shifted_action, obs_rep, obs)``;
            its result is indexed as ``[:, i, :]`` and used as logits.
        obs_rep: passed through to ``decoder`` untouched.
        obs: passed through to ``decoder`` untouched.
        batch_size: size of the leading dimension of every buffer.
        n_agent: number of agents (length of the autoregressive loop).
        action_dim: number of discrete actions.
        tpdv: keyword arguments forwarded to ``Tensor.to`` for the decoder input.
        available_actions: optional mask indexed as ``[:, i, :]``; entries equal
            to 0 have their logit overwritten with ``-1e10``.
        deterministic: take the arg-max action instead of sampling.

    Returns:
        Tuple ``(actions, log_probs, probs)`` shaped ``(batch_size, n_agent, 1)``,
        ``(batch_size, n_agent, 1)`` and ``(batch_size, n_agent, action_dim)``.
        These buffers are created without ``tpdv`` (default device).
    """
    # Decoder input: start token for agent 0, previous agent's one-hot afterwards.
    prev_actions = torch.zeros((batch_size, n_agent, action_dim + 1)).to(**tpdv)
    prev_actions[:, 0, 0] = 1

    chosen = torch.zeros((batch_size, n_agent, 1), dtype=torch.long)
    chosen_log_prob = torch.zeros_like(chosen, dtype=torch.float32)
    all_probs = torch.zeros((batch_size, n_agent, action_dim), dtype=torch.float32)
    final_agent = n_agent - 1

    for agent in range(n_agent):
        agent_logit = decoder(prev_actions, obs_rep, obs)[:, agent, :]
        if available_actions is not None:
            # Push unavailable actions to (effectively) zero probability.
            unavailable = available_actions[:, agent, :] == 0
            agent_logit[unavailable] = -1e10

        policy = Categorical(logits=agent_logit)
        if deterministic:
            picked = policy.probs.argmax(dim=-1)
        else:
            picked = policy.sample()

        chosen[:, agent, :] = picked.unsqueeze(-1)
        chosen_log_prob[:, agent, :] = policy.log_prob(picked).unsqueeze(-1)
        all_probs[:, agent, :] = policy.probs

        if agent < final_agent:
            # The next agent conditions on this agent's action.
            prev_actions[:, agent + 1, 1:] = F.one_hot(picked, num_classes=action_dim)

    return chosen, chosen_log_prob, all_probs


def discrete_parallel_act(decoder, obs_rep, obs, action, batch_size, n_agent, action_dim, tpdv, available_actions=None):
    """Score given discrete actions for all agents with a single decoder pass.

    The decoder input is built as in the autoregressive path: channel 0 is a
    start marker for agent 0 and channels ``1:`` of agent ``k`` carry the
    one-hot action of agent ``k - 1``.

    Args:
        decoder: callable invoked as ``decoder(shifted_action, obs_rep, obs)``
            whose result is used as logits.
        obs_rep: passed through to ``decoder`` untouched.
        obs: passed through to ``decoder`` untouched.
        action: integer action indices with a trailing singleton dimension
            (it is squeezed with ``squeeze(-1)`` before use).
        batch_size: leading dimension of the decoder input.
        n_agent: number of agents.
        action_dim: number of discrete actions.
        tpdv: keyword arguments forwarded to ``Tensor.to`` for the decoder input.
        available_actions: optional mask with the same shape as the logits;
            entries equal to 0 have their logit overwritten with ``-1e10``.

    Returns:
        Tuple ``(action_log, entropy)``, each with a trailing singleton dimension.
    """
    taken = action.squeeze(-1)
    taken_one_hot = F.one_hot(taken, num_classes=action_dim)  # (batch, n_agent, action_dim)

    decoder_input = torch.zeros((batch_size, n_agent, action_dim + 1)).to(**tpdv)
    decoder_input[:, 0, 0] = 1
    # Agent k sees the action of agent k - 1.
    decoder_input[:, 1:, 1:] = taken_one_hot[:, :-1, :]

    logits = decoder(decoder_input, obs_rep, obs)
    if available_actions is not None:
        unavailable = available_actions == 0
        logits[unavailable] = -1e10

    policy = Categorical(logits=logits)
    log_prob = policy.log_prob(taken).unsqueeze(-1)
    policy_entropy = policy.entropy().unsqueeze(-1)
    return log_prob, policy_entropy


def continuous_autoregreesive_act(decoder, obs_rep, obs, batch_size, n_agent, action_dim, tpdv,
                                  deterministic=False):
    """Pick continuous actions one agent at a time, feeding each action to the next agent.

    For every agent the decoder is re-run on the actions chosen so far and the
    row for that agent is used as the mean of a ``Normal`` whose standard
    deviation is ``sigmoid(decoder.log_std) * 0.5``.

    Args:
        decoder: callable invoked as ``decoder(shifted_action, obs_rep, obs)``
            that also exposes a ``log_std`` tensor attribute.
        obs_rep: passed through to ``decoder`` untouched.
        obs: passed through to ``decoder`` untouched.
        batch_size: leading dimension of every buffer.
        n_agent: number of agents (length of the autoregressive loop).
        action_dim: size of each agent's action vector.
        tpdv: keyword arguments forwarded to ``Tensor.to`` for the decoder input.
        deterministic: use the mean as the action instead of sampling.

    Returns:
        Tuple ``(actions, log_probs, exp(log_probs))``, each shaped
        ``(batch_size, n_agent, action_dim)``. Log-probabilities are
        per action component (not summed). The output buffers are created
        without ``tpdv`` (default device).
    """
    buffer_shape = (batch_size, n_agent, action_dim)
    prev_actions = torch.zeros(buffer_shape).to(**tpdv)
    actions = torch.zeros(buffer_shape, dtype=torch.float32)
    log_probs = torch.zeros_like(actions, dtype=torch.float32)
    final_agent = n_agent - 1

    for agent in range(n_agent):
        mean = decoder(prev_actions, obs_rep, obs)[:, agent, :]
        # Standard deviation is squashed into (0, 0.5).
        std = torch.sigmoid(decoder.log_std) * 0.5
        policy = Normal(mean, std)

        if deterministic:
            sampled = mean
        else:
            sampled = policy.sample()

        actions[:, agent, :] = sampled
        log_probs[:, agent, :] = policy.log_prob(sampled)

        if agent < final_agent:
            # The next agent conditions on this agent's action.
            prev_actions[:, agent + 1, :] = sampled

    return actions, log_probs, log_probs.exp()


def continuous_parallel_act(decoder, obs_rep, obs, action, batch_size, n_agent, action_dim, tpdv):
    """Score given continuous actions for all agents with a single decoder pass.

    Agent ``k`` is conditioned on the action of agent ``k - 1`` (agent 0 sees
    zeros). The decoder output is the mean of a ``Normal`` whose standard
    deviation is ``sigmoid(decoder.log_std) * 0.5``.

    Args:
        decoder: callable invoked as ``decoder(shifted_action, obs_rep, obs)``
            that also exposes a ``log_std`` tensor attribute.
        obs_rep: passed through to ``decoder`` untouched.
        obs: passed through to ``decoder`` untouched.
        action: actions to evaluate, sliced as ``action[:, :-1, :]``.
        batch_size: leading dimension of the decoder input.
        n_agent: number of agents.
        action_dim: size of each agent's action vector.
        tpdv: keyword arguments forwarded to ``Tensor.to`` for the decoder input.

    Returns:
        Tuple ``(action_log, entropy)`` of per-component log-probabilities and
        entropies (no reduction over the action dimension).
    """
    decoder_input = torch.zeros((batch_size, n_agent, action_dim)).to(**tpdv)
    # Shift by one agent: row k holds the action of agent k - 1.
    decoder_input[:, 1:, :] = action[:, :-1, :]

    mean = decoder(decoder_input, obs_rep, obs)
    # Standard deviation is squashed into (0, 0.5).
    std = torch.sigmoid(decoder.log_std) * 0.5
    policy = Normal(mean, std)

    return policy.log_prob(action), policy.entropy()


###################### begin: added helpers that run the decoder structure to produce disc values
def discrete_parallel_disc(
        decoder, obs_rep, obs, actions, use_act_embd, action_embeddings, batch_size, n_agent, action_dim, tpdv):
    """Run the decoder once on shifted discrete actions and return its raw output.

    Agent ``k`` is conditioned on the action of agent ``k - 1``. The actions
    are encoded either through ``action_embeddings`` or as one-hot vectors.

    Args:
        decoder: callable invoked as ``decoder(shifted_action, obs_rep, obs)``.
        obs_rep: passed through to ``decoder`` untouched.
        obs: passed through to ``decoder`` untouched.
        actions: action indices with a trailing singleton dimension (squeezed
            with ``squeeze(-1)`` and cast to ``long`` before use).
        use_act_embd: if true, encode actions with ``action_embeddings``;
            otherwise use a one-hot encoding preceded by a start-marker channel.
        action_embeddings: callable mapping a long index tensor to embeddings;
            only used when ``use_act_embd`` is true.
        batch_size: leading dimension of the decoder input.
        n_agent: number of agents.
        action_dim: number of discrete actions.
        tpdv: keyword arguments forwarded to ``Tensor.to`` for the decoder input.

    Returns:
        Whatever ``decoder`` returns for the shifted input.
    """
    # Both encodings share the same (batch, n_agent, action_dim + 1) input buffer.
    shifted_action = torch.zeros((batch_size, n_agent, action_dim + 1)).to(**tpdv)
    action_index = actions.squeeze(-1).long()

    if use_act_embd:
        # This is a free function, so the embedding comes from the argument
        # (the previous ``self.action_embeddings`` raised NameError).
        # NOTE(review): the assignment below requires the embedding width to be
        # compatible with ``action_dim + 1`` - confirm against the caller.
        # No start marker is written in this branch; agent 0 sees all zeros.
        embedded = action_embeddings(action_index).float()
        shifted_action[:, 1:] = embedded[:, :-1]
    else:
        one_hot = F.one_hot(action_index, num_classes=action_dim)  # (batch, n_agent, action_dim)
        shifted_action[:, 0, 0] = 1  # start marker for the first agent
        shifted_action[:, 1:, 1:] = one_hot[:, :-1, :]

    return decoder(shifted_action, obs_rep, obs)


def continuous_parallel_disc(decoder, obs_rep, obs, action, batch_size, n_agent, action_dim, tpdv):
    """Run the decoder once on shifted continuous actions and return its raw output.

    Args:
        decoder: callable invoked as ``decoder(shifted_action, obs_rep, obs)``.
        obs_rep: passed through to ``decoder`` untouched.
        obs: passed through to ``decoder`` untouched.
        action: actions, sliced as ``action[:, :-1, :]``.
        batch_size: leading dimension of the decoder input.
        n_agent: number of agents.
        action_dim: size of each agent's action vector.
        tpdv: keyword arguments forwarded to ``Tensor.to`` for the decoder input.

    Returns:
        Whatever ``decoder`` returns for the shifted input.
    """
    decoder_input = torch.zeros((batch_size, n_agent, action_dim)).to(**tpdv)
    # Row k carries the action of agent k - 1; agent 0 sees zeros.
    decoder_input[:, 1:, :] = action[:, :-1, :]
    return decoder(decoder_input, obs_rep, obs)

###################### end: added helpers that run the decoder structure to produce disc values


###################### begin: added helpers that return the action distribution (not used)
def continuous_autoregreesive_distb_gail(decoder, obs_rep, obs, batch_size, n_agent, action_dim, tpdv,
                                         deterministic=False):
    """Build one action distribution per agent, autoregressively.

    For every agent the decoder is re-run on the actions chosen so far; the row
    for that agent is the mean of a ``MultivariateNormal`` with covariance
    ``eye(action_dim) * (sigmoid(decoder.log_std) * 0.5) ** 2``. An action is
    then drawn (or the mean is taken) and fed to the next agent.

    Args:
        decoder: callable invoked as ``decoder(shifted_action, obs_rep, obs)``
            that also exposes a ``log_std`` tensor attribute.
        obs_rep: passed through to ``decoder`` untouched.
        obs: passed through to ``decoder`` untouched.
        batch_size: leading dimension of the decoder input.
        n_agent: number of agents (length of the autoregressive loop).
        action_dim: size of each agent's action vector.
        tpdv: keyword arguments forwarded to ``Tensor.to`` for the decoder
            input and the covariance matrix.
        deterministic: condition the next agent on the mean instead of a sample.

    Returns:
        List of ``n_agent`` ``MultivariateNormal`` objects, in agent order.
    """
    shifted_action = torch.zeros((batch_size, n_agent, action_dim)).to(**tpdv)

    # The covariance only depends on decoder.log_std, so build it once rather
    # than on every loop iteration.
    action_std = torch.sigmoid(decoder.log_std) * 0.5
    cov_mtx = torch.eye(action_dim).to(**tpdv) * (action_std ** 2)

    distri_list = []
    for i in range(n_agent):
        act_mean = decoder(shifted_action, obs_rep, obs)[:, i, :]
        distri = MultivariateNormal(act_mean, cov_mtx)
        distri_list.append(distri)

        # Sample for every agent (including the last one) so the random number
        # stream is consumed exactly as before. The chosen actions are not
        # returned, so no separate output buffer is kept.
        action = act_mean if deterministic else distri.sample()
        if i + 1 < n_agent:
            shifted_action[:, i + 1, :] = action

    return distri_list


def continuous_parallel_act_distb_gail(decoder, obs_rep, obs, action, batch_size, n_agent, action_dim, tpdv):
    """Build the joint action distribution for all agents with a single decoder pass.

    Agent ``k`` is conditioned on the action of agent ``k - 1`` (agent 0 sees
    zeros). The decoder output is the mean of a ``MultivariateNormal`` with
    covariance ``eye(action_dim) * (sigmoid(decoder.log_std) * 0.5) ** 2``.

    Args:
        decoder: callable invoked as ``decoder(shifted_action, obs_rep, obs)``
            that also exposes a ``log_std`` tensor attribute.
        obs_rep: passed through to ``decoder`` untouched.
        obs: passed through to ``decoder`` untouched.
        action: actions, sliced as ``action[:, :-1, :]``.
        batch_size: leading dimension of the decoder input.
        n_agent: number of agents.
        action_dim: size of each agent's action vector.
        tpdv: keyword arguments forwarded to ``Tensor.to`` for the decoder
            input and the covariance matrix.

    Returns:
        The ``MultivariateNormal`` built from the decoder output.
    """
    decoder_input = torch.zeros((batch_size, n_agent, action_dim)).to(**tpdv)
    # Row k carries the action of agent k - 1.
    decoder_input[:, 1:, :] = action[:, :-1, :]

    mean = decoder(decoder_input, obs_rep, obs)
    std = torch.sigmoid(decoder.log_std) * 0.5
    # Covariance built from the identity scaled by the per-component variance.
    variance = std ** 2
    covariance = torch.eye(action_dim).to(**tpdv) * variance

    return MultivariateNormal(mean, covariance)
###################### end: added helpers that return the action distribution (not used)
