import torch
import torch.nn.functional as F


def forward_sampling(env, forward_net):
    """Roll out the forward policy in place, one sequence position at a time.

    For every position ``0 .. env.seq_size - 1`` the rows still flagged in
    ``env.alive`` are passed through ``forward_net``; ``env.get_actions``
    (called with ``training=False``) turns the logits into one action per
    row, and that action is written into the current column of
    ``env.state``. A row whose action is 0 is cleared in ``env.alive`` and is
    not touched again. The loop ends early once no row is alive.

    Nothing is returned: the outcome is the mutated ``env.state`` and
    ``env.alive``.
    """
    for pos in range(env.seq_size):
        live_rows = torch.nonzero(env.alive, as_tuple=True)[0]
        if live_rows.numel() < 1:
            # Every row has already emitted action 0.
            break
        # Only the live rows are fed to the network.
        logits = forward_net(torch.index_select(env.state, 0, live_rows))
        chosen = env.get_actions(logits, training=False)
        # The in-place write into the state buffer is kept out of autograd.
        with torch.no_grad():
            env.state[live_rows, pos] = chosen
        env.alive[live_rows] = chosen.ne(0)
        
def forward_trajectory_log_prob(env, forward_net, training=True):
    """Roll out ``forward_net`` on ``env`` and sum the log-probabilities of the actions taken.

    At each position the live rows (``env.alive``) are scored by
    ``forward_net``, ``env.get_actions(logits, training)`` selects one action
    per row, and the log-softmax value of that action is added to the row's
    running total. The action is then written into ``env.state`` and rows
    that chose action 0 are cleared in ``env.alive``.

    Args:
        env: environment exposing ``seq_size``, ``batch_size``, ``device``,
            ``state``, ``alive`` and ``get_actions``; it is mutated in place.
        forward_net: callable mapping the live rows of the state to logits.
        training: forwarded unchanged to ``env.get_actions``.

    Returns:
        Tensor of length ``env.batch_size`` holding the accumulated
        log-probability of each row's trajectory.
    """
    total = torch.zeros(env.batch_size, device=env.device)

    for pos in range(env.seq_size):
        live_rows = torch.nonzero(env.alive, as_tuple=True)[0]
        if live_rows.numel() < 1:
            # Every row has already emitted action 0.
            break

        logits = forward_net(torch.index_select(env.state, 0, live_rows))
        chosen = env.get_actions(logits, training)

        # Log-probability of the chosen action for each live row.
        log_probs = F.log_softmax(logits, dim=-1)
        picked = torch.gather(log_probs, 1, chosen.unsqueeze(1)).squeeze(1)
        total[live_rows] += picked

        # The in-place write into the state buffer is kept out of autograd.
        with torch.no_grad():
            env.state[live_rows, pos] = chosen
        env.alive[live_rows] = chosen.ne(0)

    return total  # [B]

def dual_forward_traj_log_prob(env, fnet_phi, fnet_theta, training=True):
    """Roll out one trajectory per row and score it under two networks.

    Actions are selected by ``env.get_actions`` from the logits of
    ``fnet_phi`` only; the same actions are then scored under both
    ``fnet_phi`` and ``fnet_theta``, and each network's log-softmax value for
    the chosen action is added to its own running total. The action is
    written into ``env.state`` and rows that chose action 0 are cleared in
    ``env.alive``.

    Args:
        env: environment exposing ``seq_size``, ``batch_size``, ``device``,
            ``state``, ``alive`` and ``get_actions``; it is mutated in place.
        fnet_phi: network whose logits drive action selection.
        fnet_theta: second network, evaluated on the same states and actions.
        training: forwarded unchanged to ``env.get_actions``.

    Returns:
        Tuple ``(log_p_phi, log_p_theta)``, each a tensor of length
        ``env.batch_size``.
    """
    total_phi = torch.zeros(env.batch_size, device=env.device)
    total_theta = torch.zeros(env.batch_size, device=env.device)

    for pos in range(env.seq_size):
        live_rows = torch.nonzero(env.alive, as_tuple=True)[0]
        if live_rows.numel() < 1:
            # Every row has already emitted action 0.
            break

        live_states = torch.index_select(env.state, 0, live_rows)
        phi_logits = fnet_phi(live_states)
        theta_logits = fnet_theta(live_states)

        # Only the phi network decides which action is taken.
        chosen = env.get_actions(phi_logits, training)
        column = chosen.unsqueeze(1)

        phi_log_probs = F.log_softmax(phi_logits, dim=-1)
        total_phi[live_rows] += torch.gather(phi_log_probs, 1, column).squeeze(1)
        theta_log_probs = F.log_softmax(theta_logits, dim=-1)
        total_theta[live_rows] += torch.gather(theta_log_probs, 1, column).squeeze(1)

        # The in-place write into the state buffer is kept out of autograd.
        with torch.no_grad():
            env.state[live_rows, pos] = chosen
        env.alive[live_rows] = chosen.ne(0)

    return total_phi, total_theta  # [B], [B]

def forward_reward(env, forward_net, logz):
    """Return the rolled-out trajectory log-probability plus ``logz``.

    Runs ``forward_trajectory_log_prob`` on ``env`` with ``training=True``
    (which mutates ``env``) and adds ``logz`` to the per-row result.
    """
    return forward_trajectory_log_prob(env, forward_net, training=True) + logz  # [B]

def dual_forward_reward(env, fnet_phi, fnet_theta, logz_phi, logz_theta):
    """Return both networks' trajectory log-probabilities, each offset by its own log-Z.

    Runs ``dual_forward_traj_log_prob`` on ``env`` with ``training=True``
    (which mutates ``env``), then adds ``logz_phi`` to the phi result and
    ``logz_theta`` to the theta result.
    """
    phi_and_theta = dual_forward_traj_log_prob(env, fnet_phi, fnet_theta, training=True)
    phi_total, theta_total = phi_and_theta
    phi_reward = phi_total + logz_phi
    theta_reward = theta_total + logz_theta
    return phi_reward, theta_reward  # [B], [B]

def marginal_log_reward(net, logz, eval_state):
    """Score fixed sequences under ``net`` and add ``logz``.

    The sequences in ``eval_state`` are replayed position by position: the
    partial state built so far is fed to ``net``, the log-softmax value of
    the action stored in ``eval_state`` at that position is accumulated, and
    the action is then written into the partial state. A row stops
    contributing once its action is 0 (the STOP action); the 0 itself is
    still scored.

    All working tensors are allocated on ``eval_state.device`` so the
    function also works when the inputs live on an accelerator.

    Args:
        net: callable mapping a batch of partial states to action logits.
        logz: value added to every row's accumulated log-probability.
        eval_state: 2-D integer tensor of actions, one sequence per row
            (0 = STOP). Must be usable as a ``gather`` index, i.e. int64.

    Returns:
        Tensor with one entry per row of ``eval_state``: the accumulated
        log-probability plus ``logz``.
    """
    B, L = eval_state.shape
    device = eval_state.device
    logp_total = torch.zeros(B, device=device)
    cur_state = torch.zeros((B, L), dtype=torch.long, device=device)
    alive = torch.ones(B, dtype=torch.bool, device=device)
    for i in range(L):
        active = alive.nonzero(as_tuple=True)[0]
        if active.numel() == 0:
            # Every sequence has already reached its STOP action.
            break
        s_sub = cur_state.index_select(0, active)  # [B_t, L]
        a = eval_state.index_select(0, active)[:, i]  # [B_t] (0=STOP, 1..V)
        logits = net(s_sub)
        logp = F.log_softmax(logits, dim=-1).gather(1, a.unsqueeze(1)).squeeze(1)
        logp_total.index_add_(0, active, logp)
        cur_state[active, i] = a
        alive[active] = (a != 0)
    return logp_total + logz